Update reducer to latest design (that runs at 159 MHz)
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies, 
2              ScopedTypeVariables, RecordWildCards #-}
3 module Reducer where
4
5 import qualified Prelude as P
6 import CLasH.HardwareTypes hiding ((>>))
7 import CLasH.Translator.Annotations
8
9 type Signed   = SizedInt
10 type Unsigned = SizedWord
11 type Index    = RangedWord
12
13 -- =======================================
14 -- = System size configuration variables =
15 -- =======================================
16 type DataSize       = D64
17 type IndexSize      = D16
18 type DiscrSize      = D7
19 type AdderDepth     = D12
20
21 -- Derived configuration variables
22 type DiscrRange     = (Pow2 DiscrSize) :-: D1
23 type AdderDepthPL   = AdderDepth :+: D3
24
25 -- =================
26 -- = Type Aliasses =
27 -- =================
28 type Shift          = Index D2
29 type DataInt        = Signed DataSize
30 type ArrayIndex     = Unsigned IndexSize
31 type Discr          = Index DiscrRange
32 type OutputSignal   = ((DataInt, ArrayIndex), Bool)
33
34 -- =================================
35 -- = Cell Definition and Accessors =
36 -- =================================
37 type CellType       = Bool                     
38 type Cell           = (CellType, (DataInt, Discr))
39
40 valid :: Cell -> Bool
41 valid (x, _) = x
42
43 value :: Cell -> DataInt
44 value (_, (v, _)) = v
45
46 discr :: Cell -> Discr
47 discr (_, (_, d)) = d
48
49 notValid :: Cell
50 notValid = (False, (0, 0))
51
52 -- ====================
53 -- = Helper functions =
54 -- ====================
55 v << e = shiftr v e
56 e >> v = shiftl v e
57
58 -- =======================
59 -- = Reducer State types =
60 -- =======================
61 data DiscrRecord =
62   DiscrR { prev_index  ::  ArrayIndex
63          , cur_discr   ::  SizedWord DiscrSize
64          }
65 type DiscrState = State DiscrRecord
66                             
67 type RippleState = 
68   State (Vector (AdderDepthPL :+: D1) (CellType, Discr))
69
70 data BlockRecord = 
71   Block { ptrs    ::  (SizedWord D4, SizedWord D4, SizedWord D4)
72         , buf1    ::  MemState AdderDepthPL DataInt
73         , buf2    ::  MemState AdderDepthPL DataInt
74         }
75 type BlockState = State BlockRecord
76
77 type FpAdderState = 
78   State ( ( (DataInt, DataInt, DataInt, DataInt)  -- Buffer input and double buffer output of the FPAdder
79           , Vector AdderDepthPL (CellType, Discr) -- Validbits & discriminators of values in the FP pipeline
80           )
81         , FpPlaceholder
82         )
83
84 type FpPlaceholder = 
85   State (Vector AdderDepth DataInt)
86
87 data OutputRecord = 
88   Outp { valid_mem :: RAM DiscrRange CellType
89        , mem1      :: MemState DiscrRange DataInt
90        , mem2      :: MemState DiscrRange DataInt
91        , lutm      :: MemState DiscrRange ArrayIndex
92        }                            
93 type OutputState = State OutputRecord
94
95 data ReducerRecord = 
96   Reducer { discrState  :: DiscrState
97           , rippleState :: RippleState
98           , blockState  :: BlockState
99           , pipeState   :: FpAdderState
100           , resultState :: OutputState          
101           , pipeline    :: ( Vector AdderDepth (Discr, ArrayIndex, Bool) -- Buffer link between discriminator and Result buffer                                                 
102                            , CellType                                    -- Buffer Valid bit of the resultbuffer at T+1                                                 
103                            , Vector D2 (DataInt, ArrayIndex)             -- Buffer Input (to encourage retiming)
104                            , Vector D2 OutputSignal                      -- Buffer Output (to encourage retiming)
105                            )
106           }                              
107 type ReducerState   = State ReducerRecord
108
109 -- ===========================================================
110 -- = Discrimintor: Hands out new discriminator to the system =
111 -- ===========================================================
112 {-# ANN discriminator (InitState 'initDiscrState) #-}
113 discriminator :: 
114   DiscrState -> 
115   ArrayIndex -> 
116   (DiscrState, Discr, Bool)
117 discriminator (State (DiscrR {..})) index = ( State DiscrR { prev_index = index
118                                                            , cur_discr  = cur_discr'
119                                                            }
120                                             , discr
121                                             , new_discr
122                                             )
123   where
124     new_discr               = index /= prev_index
125     cur_discr'  | new_discr = cur_discr + 1
126                 | otherwise = cur_discr
127     discr                   = fromSizedWord cur_discr'
128
129 -- ======================================================
130 -- = Input Buffer: Buffers incomming inputs when needed =
131 -- ======================================================
132 {-# ANN rippleBuffer (InitState 'initRippleState) #-}
133 rippleBuffer :: 
134   RippleState ->
135   (Discr, Shift) ->
136   (RippleState, (CellType, Discr), (CellType, Discr))
137 rippleBuffer (State buf) (inp, shift) = (State buf', out1, out2)
138   where
139     -- Write value
140     next_valids              = (map fst buf) << True
141     buf''                    = zipWith selects buf next_valids
142     selects cell next_valid  = if (not (fst cell)) && next_valid then
143                                  (True, inp)
144                                else
145                                  cell
146     -- Shift values                            
147     buf'        | shift == 2 = (False, 0) >> ((False, 0) >> buf'')
148                 | shift == 1 = (False, 0) >> buf''
149                 | otherwise  = buf''
150     -- Read values
151     out1                     = last buf
152     out2                     = last (init buf)
153     
154 {-# ANN blockBuffer (InitState 'initBlockState) #-}
155 blockBuffer :: 
156   BlockState ->
157   (DataInt, Shift) ->
158   (BlockState, DataInt, DataInt)
159 blockBuffer (State (Block {..})) (inp, shift) = ( State Block { ptrs = ptrs'
160                                                               , buf1 = buf1'
161                                                               , buf2 = buf2'
162                                                               }
163                                                 , out1, out2)
164   where 
165     -- Do some state (un)packing
166     (rd_ptr1, rd_ptr2, wr_ptr) = ptrs
167     ptrs'                      = (rd_ptr1', rd_ptr2', wr_ptr')
168     -- Update pointers               
169     count                      = fromRangedWord shift
170     (rd_ptr1', rd_ptr2')       = (rd_ptr1 + count, rd_ptr2 + count)
171     wr_ptr'                    = wr_ptr + 1
172     -- Write & Read from RAMs
173     (buf1', out1)              = blockRAM buf1 inp (fromSizedWord rd_ptr1) (fromSizedWord wr_ptr) True
174     (buf2', out2)              = blockRAM buf2 inp (fromSizedWord rd_ptr2) (fromSizedWord wr_ptr) True
175     
176 -- ============================================
177 -- = Simulated pipelined floating point adder =
178 -- ============================================
179 {-# ANN fpAdder (InitState 'initPipeState) #-}
180 fpAdder :: 
181   FpAdderState -> 
182   (Cell, Cell) -> 
183   (FpAdderState, (Cell, Cell))         
184 fpAdder (State ((buffer, pipe), adderState)) (arg1, arg2) = (State ((buffer', pipe'), adderState'), (pipeT_1, pipeT))
185   where
186     -- Do some state (un)packing
187     (a1,a2,dataT_1,dataT)   = buffer
188     buffer'                 = (value arg1, value arg2, adderOut, dataT_1)
189     -- placeholder adder
190     (adderState', adderOut) = fpPlaceholder adderState (a1, a2)
191     -- Save corresponding indexes and valid bits      
192     pipe'                   = (valid arg1, discr arg1) >> pipe
193     -- Produce output for time T and T+1
194     pipeEndT                = last pipe
195     pipeEndT_1              = last (init pipe)
196     pipeT                   = (fst pipeEndT, (dataT, snd pipeEndT))
197     pipeT_1                 = (fst pipeEndT_1,(dataT_1,snd pipeEndT_1))
198
199 {-# ANN fpPlaceholder (InitState 'initAdderState) #-}
200 fpPlaceholder :: FpPlaceholder -> (DataInt, DataInt) -> (FpPlaceholder, DataInt)
201 fpPlaceholder (State pipe) (arg1, arg2) = (State pipe', pipe_out)
202   where
203     pipe'       = (arg1 + arg2) +> init pipe
204     pipe_out    = last pipe
205
206 -- ===================================================
207 -- = Optimized Partial Result Buffer, uses BlockRAMs =
208 -- ===================================================
209 {-# ANN resBuff (InitState 'initResultState) #-}                                   
210 resBuff :: 
211   OutputState -> 
212   ( Cell, Cell, Bool, (Discr, ArrayIndex, Bool)) -> 
213   (OutputState, Cell, OutputSignal)
214 resBuff (State (Outp {..})) (pipeT, pipeT_1, new_cell, (discrN, index, new_discr)) = ( State Outp { valid_mem = valid_mem'
215                                                                                                   , mem1      = mem1'
216                                                                                                   , mem2      = mem2'
217                                                                                                   , lutm      = lutm'
218                                                                                                   }
219                                                                                      , res_mem_out, output)
220   where
221     addrT                         = discr pipeT
222     addrT_1                       = discr pipeT_1
223     -- Purge completely reduced results from the system
224     clean_mem   | new_discr       = replace valid_mem discrN False
225                 | otherwise       = valid_mem
226     -- If a partial is fed  back to the pipeline, make its location invalid   
227     valid_mem'  | new_cell        = replace clean_mem addrT True
228                 | otherwise       = replace clean_mem addrT False
229     -- Two parrallel memories with the same write addr, but diff rdaddr for partial res and other for complete res
230     (mem1', partial)              = blockRAM mem1 (value pipeT) addrT   addrT new_cell
231     (mem2', complete)             = blockRAM mem2 (value pipeT) discrN  addrT new_cell
232     -- Lut maps discriminators to array index
233     (lutm', lut_out)              = blockRAM lutm index discrN discrN new_discr
234     res_mem_out                   = (valid_mem!addrT_1, (partial,addrT))
235     -- Output value to the system once a discriminator is reused
236     output                        = ((complete,lut_out), new_discr && (valid_mem!discrN))
237
238 -- ================================================================
239 -- = Controller guides correct inputs to the floating point adder =
240 -- ================================================================
241 controller :: 
242   (Cell, Cell, Cell, Cell) -> 
243   (Cell, Cell, Shift, Bool)
244 controller (inp1, inp2, pipeT, from_res_mem) = (arg1, arg2, shift, to_res_mem)
245   where
246     (arg1, arg2, shift, to_res_mem)
247       | valid pipeT && valid from_res_mem                          = (pipeT   , from_res_mem            , 0, False)
248       | valid pipeT && valid inp1 && discr pipeT == discr inp1     = (pipeT   , inp1                    , 1, False)
249       | valid inp1  && valid inp2 && discr inp1 == discr inp2      = (inp1    , inp2                    , 2, valid pipeT)
250       | valid inp1                                                 = (inp1    , (True, (0, discr inp1)) , 1, valid pipeT)
251       | otherwise                                                  = (notValid, notValid                , 0, valid pipeT)
252
253 -- =============================================
254 -- = Reducer: Wrap up all the above components =
255 -- =============================================
256 {-# ANN reducer TopEntity #-}
257 {-# ANN reducer (InitState 'initReducerState) #-}   
258 reducer :: 
259   ReducerState -> 
260   (DataInt, ArrayIndex) -> 
261   (ReducerState, OutputSignal)
262 reducer (State (Reducer {..})) (data_in, index) = ( State Reducer { discrState  = discrState'
263                                                                   , rippleState = rippleState'
264                                                                   , blockState  = blockState'
265                                                                   , pipeState   = pipeState'
266                                                                   , resultState = resultState'
267                                                                   , pipeline    = pipeline'
268                                                                   }
269                                                   , last outPipe)
270   where
271     -- Discriminator
272     (discrState' , discrN, new_discr)       = discriminator discrState (snd (last inPipe))
273     -- InputBuffer
274     (rippleState' , (inp1V, inp1I), (inp2V, inp2I)) = rippleBuffer rippleState (discrN, shift)
275     (blockState', inp1D, inp2D)             = blockBuffer blockState ((fst (last inPipe)), shift)
276     (inp1,inp2)                             = ((inp1V,(inp1D,inp1I)),(inp2V,(inp2D,inp2I))) 
277     -- FP Adder    
278     (pipeState'  , (pipeT_1, pipeT))        = fpAdder pipeState (arg1, arg2)
279     -- Result Buffer
280     (resultState', from_res_mem, output')   = resBuff resultState (pipeT, pipeT_1, to_res_mem, last discrO)
281     -- Controller
282     (arg1,arg2,shift,to_res_mem)            = controller (inp1, inp2, pipeT, (valT, snd from_res_mem))
283     -- Optimizations/Pipelining
284     valT_1  | discr pipeT == discr pipeT_1  = not (valid from_res_mem)
285             | otherwise                     = valid from_res_mem
286     (discrO, valT, inPipe , outPipe)        = pipeline    
287     pipeline'                               = ( (discrN, index, new_discr) >> discrO
288                                               , valT_1
289                                               , (data_in, index) >> inPipe
290                                               , output' >> outPipe
291                                               )
292
293
294 -- ========================
295 -- = Initial State values =
296 -- ========================
297 initDiscrState :: DiscrRecord
298 initDiscrState = DiscrR { prev_index = 255
299                         , cur_discr  = 127
300                         }
301                                          
302 initRippleState :: Vector (AdderDepthPL :+: D1) (CellType, Discr)
303 initRippleState = copy (False, 0)
304
305 initBlockState :: (SizedWord D4, SizedWord D4, SizedWord D4)
306 initBlockState = (0,1,0)
307                      
308 initPipeState :: 
309   ((DataInt,DataInt,DataInt,DataInt)
310   , Vector AdderDepthPL (CellType, Discr)
311   )
312 initPipeState = ((0,0,0,0),copy (False, 0))     
313
314 initAdderState :: Vector AdderDepth DataInt
315 initAdderState = copy 0
316
317 initResultState :: RAM DiscrRange CellType     
318 initResultState = copy False
319
320 initReducerState :: 
321   ( Vector AdderDepth (Discr, ArrayIndex, Bool)
322   , CellType
323   , Vector D2 (DataInt, ArrayIndex)
324   , Vector D2 OutputSignal
325   )
326 initReducerState = (copy (0, 0, False), False, copy (0,0), copy ((0,0), False))
327
328 initstate :: ReducerState
329 initstate = State ( Reducer { discrState  = State initDiscrState
330                             , rippleState = State initRippleState
331                             , blockState  = State Block { ptrs = initBlockState
332                                                         , buf1 = State (copy 0)
333                                                         , buf2 = State (copy 0)
334                                                         }
335                             , pipeState   = State (initPipeState, State initAdderState)
336                             , resultState = State Outp { valid_mem = initResultState 
337                                                        , mem1      = State (copy 0)
338                                                        , mem2      = State (copy 0)
339                                                        , lutm      = State (copy 0)
340                                                        }
341                             , pipeline    = initReducerState
342                             })
343
344 -- ==================
345 -- = Test Functions =
346 -- ==================          
347 run func state [] = []
348 run func state (i:input) = o:out
349   where
350     (state', o) = func state i
351     out         = run func state' input
352
353 runReducer =  ( reduceroutput
354               , validoutput
355               , equal
356               , allEqual
357               )
358   where
359     -- input = randominput 900 7
360     input  = siminput
361     istate = initstate
362     output = run reducer istate input
363     reduceroutput = P.map fst (filter (\x -> (snd x)) output)
364     validoutput   = [P.foldl (+) 0 
365                       (P.map (\z -> toInteger (fst z)) 
366                         (filter (\x -> (snd x) == i) input)) | i <- [0..30]]
367     equal = [validoutput!!i == toInteger (fst (reduceroutput!!i)) | 
368               i <- [0..30]]
369     allEqual = foldl1 (&&) equal
370
371 siminput :: [(DataInt, ArrayIndex)]
372 siminput =  [(1,0),(5,1),(12,1),(4,2),(9,2),(2,2),(13,2),(2,2),(6,2),(1,2),(12,2),(13,3),(6,3),(11,3),(2,3),(11,3),(5,4),(11,4),(1,4),(7,4),(3,4),(4,4),(5,5),(8,5),(8,5),(13,5),(10,5),(7,5),(9,6),(9,6),(3,6),(11,6),(14,6),(13,6),(10,6),(4,7),(15,7),(13,7),(10,7),(10,7),(6,7),(15,7),(9,7),(1,7),(7,7),(15,7),(3,7),(13,7),(7,8),(3,9),(13,9),(2,10),(9,11),(10,11),(9,11),(2,11),(14,12),(14,12),(12,13),(7,13),(9,13),(7,14),(14,15),(5,16),(6,16),(14,16),(11,16),(5,16),(5,16),(7,17),(1,17),(13,17),(10,18),(15,18),(12,18),(14,19),(13,19),(2,19),(3,19),(14,19),(9,19),(11,19),(2,19),(2,20),(3,20),(13,20),(3,20),(1,20),(9,20),(10,20),(4,20),(8,21),(4,21),(8,21),(4,21),(13,21),(3,21),(7,21),(12,21),(7,21),(13,21),(3,21),(1,22),(13,23),(9,24),(14,24),(4,24),(13,25),(6,26),(12,26),(4,26),(15,26),(3,27),(6,27),(5,27),(6,27),(12,28),(2,28),(8,28),(5,29),(4,29),(1,29),(2,29),(9,29),(10,29),(4,30),(6,30),(14,30),(11,30),(15,31),(15,31),(2,31),(14,31),(9,32),(3,32),(4,32),(6,33),(15,33),(1,33),(15,33),(4,33),(3,33),(8,34),(12,34),(14,34),(15,34),(4,35),(4,35),(12,35),(14,35),(3,36),(14,37),(3,37),(1,38),(15,39),(13,39),(13,39),(1,39),(5,40),(10,40),(14,40),(1,41),(6,42),(8,42),(11,42),(11,43),(2,43),(11,43),(8,43),(12,43),(15,44),(14,44),(6,44),(8,44),(9,45),(5,45),(12,46),(6,46),(5,46),(4,46),(2,46),(9,47),(7,48),(1,48),(3,48),(10,48),(1,48),(6,48),(6,48),(11,48),(11,48),(8,48),(14,48),(5,48),(11,49),(1,49),(3,49),(11,49),(8,49),(3,50),(8,51),(9,52),(7,52),(7,53),(8,53),(10,53),(11,53),(14,54),(11,54),(4,54),(6,55),(11,55),(5,56),(7,56),(6,56),(2,56),(4,56),(12,56),(4,57),(12,57),(2,57),(14,57),(9,57),(12,57),(5,57),(11,57),(7,58),(14,58),(2,58),(10,58),(2,58),(14,58),(7,58),(12,58),(1,58),(11,59),(8,59),(2,59),(14,59),(6,59),(6,59),(6,59),(14,59),(4,59),(1,59),(4,60),(14,60),(6,60),(4,60),(8,60),(12,60),(1,60),(8,60),(8,60),(13,60),(10,61),(11,61),(6,61),(14,61),(10,61),(3,62),(10,62),(7,62),(14,62),(10,62),(4,62),(6,62),(1,62),(3,63),(3,63),(1,63),(1,63),(15,63),(7,64),(1,65),(4,65),(11,66),(3,66),(13,66),(2,67),(2,67),(5,68),(15,68),(11,68),(8,68),(4,69),(11,69),(12,69),(8,69),(7,70),(9,70),(6,70),(9,70),(11,70),(14,70),(5,71),(7,71),(11,72),(5,72),(3,72),(2,72),(1,73),(13,73),(9,73),(14,73),(5,73),(6,73),(14,73),(13,73),(3,74),(13,74),(3,75),(14,75),(10,75),(5,75),(3,75),(8,75),(9,76),(7,76),(10,76),(10,76),(8,77),(10,77),(11,77),(8,77),(2,77),(9,77),(9,77),(12,77),(4,77),(14,77),(10,77),(7,77),(3,77),(10,78),(8,79),(14,79),(11,80),(15,81),(6,81),(4,82),(6,82),(1,82),(12,83),(6,83),(11,83),(12,83),(15,83),(13,83),(1,84),(2,84),(11,84),(5,84),(2,84),(2,84),(3,84),(4,85),(6,86),(5,86),(15,86),(8,86),(9,86),(9,87),(9,87),(12,87),(4,87),(13,88),(14,88),(10,88),(11,88),(7,88),(4,88),(9,88),(1,88),(4,88),(4,88),(12,88),(8,89),(3,89),(10,89),(10,89),(5,89),(14,89),(11,89),(10,89),(5,90),(6,90),(10,90),(9,90),(8,90),(10,90),(5,90),(11,90),(6,90),(10,90),(7,90),(3,91),(7,91),(5,91),(15,91),(4,91),(6,91),(8,91),(1,91),(8,91),(12,92),(8,93),(9,93),(12,94),(8,94),(5,94),(11,95),(13,95),(5,96),(12,96),(8,96),(4,96),(7,97),(6,97),(4,97),(1,98),(5,98),(12,98),(13,99),(7,100),(12,100),(4,100),(10,100),(2,101),(3,101),(14,101),(12,101),(5,101),(2,101),(14,101),(15,101),(7,102),(13,102),(5,102),(7,102),(4,102),(8,102),(12,103),(15,103),(2,103),(2,103),(6,103),(6,103),(1,104),(14,104),(15,105),(3,105),(13,105),(1,105),(8,105),(8,105),(15,105),(13,105),(13,105),(6,105),(9,105),(6,106),(14,107),(12,107),(7,108),(7,108),(6,109),(11,109),(14,110),(8,111),(5,111),(15,111),(14,111),(3,111),(13,112),(12,112),(5,112),(10,112),(7,112),(5,113),(3,113),(2,113),(1,113),(15,113),(8,113),(10,113),(3,114),(6,114),(15,114),(4,115),(8,115),(1,115),(12,115),(5,115),(6,116),(2,116),(13,116),(12,116),(6,116),(10,117),(8,117),(14,118),(10,118),(3,118),(15,119),(6,119),(6,120),(5,121),(8,121),(4,122),(1,122),(9,123),(12,123),(6,124),(10,124),(2,124),(11,124),(9,125),(8,126),(10,126),(11,126),(14,126),(2,126),(5,126),(7,126),(3,127),(12,127),(15,128),(4,128),(1,129),(14,129),(8,129),(9,129),(6,129),(1,130),(11,130),(2,130),(13,130),(14,131),(2,131),(15,131),(4,131),(15,131),(8,131),(3,131),(8,132),(1,132),(13,132),(8,132),(5,132),(11,132),(14,132),(14,132),(4,132),(14,132),(5,132),(11,133),(1,133),(15,133),(8,133),(12,133),(8,134),(14,135),(11,136),(9,137),(3,137),(15,138),(1,138),(1,139),(4,139),(3,140),(10,140),(8,141),(12,141),(4,141),(12,141),(13,141),(10,141),(4,142),(6,142),(15,142),(4,142),(2,143),(14,143),(5,143),(10,143),(8,143),(9,143),(3,143),(11,143),(6,144),(3,145),(9,145),(10,145),(6,145),(11,145),(4,145),(13,145),(5,145),(4,145),(1,145),(3,145),(15,145),(14,146),(11,146),(9,146),(9,146),(10,146),(9,146),(3,146),(2,146),(10,146),(6,146),(7,146),(3,147),(4,147),(15,147),(11,147),(15,147),(1,147),(15,147),(14,147),(15,147),(5,147),(15,147),(4,147),(2,148),(12,149),(12,150),(10,150),(1,150),(7,151),(4,151),(14,151),(15,151),(5,152),(11,153),(3,153),(1,153),(1,153),(12,153),(1,154),(1,155),(11,155),(8,155),(3,155),(8,155),(8,155),(2,155),(9,156),(6,156),(12,156),(1,156),(3,156),(8,156),(5,157),(9,157),(12,157),(6,157),(8,158),(15,159),(2,159),(10,160),(10,160),(2,160),(6,160),(10,160),(8,160),(13,160),(12,161),(15,161),(14,161),(10,161),(13,161),(14,161),(3,161),(2,161),(1,161),(11,161),(7,161),(8,161),(4,162),(9,163),(3,164),(5,164),(9,164),(9,165),(7,165),(1,165),(6,166),(14,166),(3,166),(14,166),(4,166),(14,167),(5,167),(13,167),(12,167),(13,168),(9,168)]