Do not be overzealous with inlining results of polymorphic functions
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies, 
2              ScopedTypeVariables, RecordWildCards #-}
3 module Reducer where
4
5 import qualified Prelude as P
6 import CLasH.HardwareTypes hiding ((>>))
7 import CLasH.Translator.Annotations
8
9 -- =======================================
10 -- = System size configuration variables =
11 -- =======================================
12 type DataSize       = D64
13 type IndexSize      = D16
14 type DiscrSize      = D7
15 type AdderDepth     = D12
16
17 -- Derived configuration variables
18 type DiscrRange     = (Pow2 DiscrSize) :-: D1
19 type AdderDepthPL   = AdderDepth :+: D3
20
21 -- =================
22 -- = Type Aliasses =
23 -- =================
24 type Shift          = Index D2
25 type DataInt        = Signed DataSize
26 type ArrayIndex     = Unsigned IndexSize
27 type Discr          = Index DiscrRange
28 type OutputSignal   = ((DataInt, ArrayIndex), Bool)
29
30 -- =================================
31 -- = Cell Definition and Accessors =
32 -- =================================
33 type CellType       = Bool                     
34 type Cell           = (CellType, (DataInt, Discr))
35
36 valid :: Cell -> Bool
37 valid (x, _) = x
38
39 value :: Cell -> DataInt
40 value (_, (v, _)) = v
41
42 discr :: Cell -> Discr
43 discr (_, (_, d)) = d
44
45 notValid :: Cell
46 notValid = (False, (0, 0))
47
48 -- ====================
49 -- = Helper functions =
50 -- ====================
51 v << e = shiftr v e
52 e >> v = shiftl v e
53
54 -- =======================
55 -- = Reducer State types =
56 -- =======================
57 data DiscrRecord =
58   DiscrR { prev_index  ::  ArrayIndex
59          , cur_discr   ::  Unsigned DiscrSize
60          }
61 type DiscrState = State DiscrRecord
62                             
63 type RippleState = 
64   State (Vector (AdderDepthPL :+: D1) (CellType, Discr))
65
66 data BlockRecord = 
67   Block { ptrs    ::  (Unsigned D4, Unsigned D4, Unsigned D4)
68         , buf1    ::  MemState AdderDepthPL DataInt
69         , buf2    ::  MemState AdderDepthPL DataInt
70         }
71 type BlockState = State BlockRecord
72
73 type FpAdderState = 
74   State ( ( (DataInt, DataInt, DataInt, DataInt)  -- Buffer input and double buffer output of the FPAdder
75           , Vector AdderDepthPL (CellType, Discr) -- Validbits & discriminators of values in the FP pipeline
76           )
77         , FpPlaceholder
78         )
79
80 type FpPlaceholder = 
81   State (Vector AdderDepth DataInt)
82
83 data OutputRecord = 
84   Outp { valid_mem :: RAM DiscrRange CellType
85        , mem1      :: MemState DiscrRange DataInt
86        , mem2      :: MemState DiscrRange DataInt
87        , lutm      :: MemState DiscrRange ArrayIndex
88        }                            
89 type OutputState = State OutputRecord
90
91 data ReducerRecord = 
92   Reducer { discrState  :: DiscrState
93           , rippleState :: RippleState
94           , blockState  :: BlockState
95           , pipeState   :: FpAdderState
96           , resultState :: OutputState          
97           , pipeline    :: ( Vector AdderDepth (Discr, ArrayIndex, Bool) -- Buffer link between discriminator and Result buffer                                                 
98                            , CellType                                    -- Buffer Valid bit of the resultbuffer at T+1                                                 
99                            , Vector D2 (DataInt, ArrayIndex)             -- Buffer Input (to encourage retiming)
100                            , Vector D2 OutputSignal                      -- Buffer Output (to encourage retiming)
101                            )
102           }                              
103 type ReducerState   = State ReducerRecord
104
105 -- ===========================================================
106 -- = Discrimintor: Hands out new discriminator to the system =
107 -- ===========================================================
108 {-# ANN discriminator (InitState 'initDiscrState) #-}
109 discriminator :: 
110   DiscrState -> 
111   ArrayIndex -> 
112   (DiscrState, Discr, Bool)
113 discriminator (State (DiscrR {..})) index = ( State DiscrR { prev_index = index
114                                                            , cur_discr  = cur_discr'
115                                                            }
116                                             , discr
117                                             , new_discr
118                                             )
119   where
120     new_discr               = index /= prev_index
121     cur_discr'  | new_discr = cur_discr + 1
122                 | otherwise = cur_discr
123     discr                   = fromUnsigned cur_discr'
124
125 -- ======================================================
126 -- = Input Buffer: Buffers incomming inputs when needed =
127 -- ======================================================
128 {-# ANN rippleBuffer (InitState 'initRippleState) #-}
129 rippleBuffer :: 
130   RippleState ->
131   (Discr, Shift) ->
132   (RippleState, (CellType, Discr), (CellType, Discr))
133 rippleBuffer (State buf) (inp, shift) = (State buf', out1, out2)
134   where
135     -- Write value
136     next_valids              = (map (\(a, b) -> a) buf) << True
137     buf''                    = zipWith selects buf next_valids
138     selects cell next_valid  = if (not (fst cell)) && next_valid then
139                                  (True, inp)
140                                else
141                                  cell
142     -- Shift values                            
143     buf'        | shift == 2 = (False, 0) >> ((False, 0) >> buf'')
144                 | shift == 1 = (False, 0) >> buf''
145                 | otherwise  = buf''
146     -- Read values
147     out1                     = last buf
148     out2                     = last (init buf)
149     
150 {-# ANN blockBuffer (InitState 'initBlockState) #-}
151 blockBuffer :: 
152   BlockState ->
153   (DataInt, Shift) ->
154   (BlockState, DataInt, DataInt)
155 blockBuffer (State (Block {..})) (inp, shift) = ( State Block { ptrs = ptrs'
156                                                               , buf1 = buf1'
157                                                               , buf2 = buf2'
158                                                               }
159                                                 , out1, out2)
160   where 
161     -- Do some state (un)packing
162     (rd_ptr1, rd_ptr2, wr_ptr) = ptrs
163     ptrs'                      = (rd_ptr1', rd_ptr2', wr_ptr')
164     -- Update pointers               
165     count                      = fromIndex shift
166     (rd_ptr1', rd_ptr2')       = (rd_ptr1 + count, rd_ptr2 + count)
167     wr_ptr'                    = wr_ptr + 1
168     -- Write & Read from RAMs
169     (buf1', out1)              = blockRAM buf1 inp (fromUnsigned rd_ptr1) (fromUnsigned wr_ptr) True
170     (buf2', out2)              = blockRAM buf2 inp (fromUnsigned rd_ptr2) (fromUnsigned wr_ptr) True
171     
172 -- ============================================
173 -- = Simulated pipelined floating point adder =
174 -- ============================================
175 {-# ANN fpAdder (InitState 'initPipeState) #-}
176 fpAdder :: 
177   FpAdderState -> 
178   (Cell, Cell) -> 
179   (FpAdderState, (Cell, Cell))         
180 fpAdder (State ((buffer, pipe), adderState)) (arg1, arg2) = (State ((buffer', pipe'), adderState'), (pipeT_1, pipeT))
181   where
182     -- Do some state (un)packing
183     (a1,a2,dataT_1,dataT)   = buffer
184     buffer'                 = (value arg1, value arg2, adderOut, dataT_1)
185     -- placeholder adder
186     (adderState', adderOut) = fpPlaceholder adderState (a1, a2)
187     -- Save corresponding indexes and valid bits      
188     pipe'                   = (valid arg1, discr arg1) >> pipe
189     -- Produce output for time T and T+1
190     pipeEndT                = last pipe
191     pipeEndT_1              = last (init pipe)
192     pipeT                   = (fst pipeEndT, (dataT, snd pipeEndT))
193     pipeT_1                 = (fst pipeEndT_1,(dataT_1,snd pipeEndT_1))
194
195 {-# ANN fpPlaceholder (InitState 'initAdderState) #-}
196 fpPlaceholder :: FpPlaceholder -> (DataInt, DataInt) -> (FpPlaceholder, DataInt)
197 fpPlaceholder (State pipe) (arg1, arg2) = (State pipe', pipe_out)
198   where
199     pipe'       = (arg1 + arg2) +> init pipe
200     pipe_out    = last pipe
201
202 -- ===================================================
203 -- = Optimized Partial Result Buffer, uses BlockRAMs =
204 -- ===================================================
205 {-# ANN resBuff (InitState 'initResultState) #-}                                   
206 resBuff :: 
207   OutputState -> 
208   ( Cell, Cell, Bool, (Discr, ArrayIndex, Bool)) -> 
209   (OutputState, Cell, OutputSignal)
210 resBuff (State (Outp {..})) (pipeT, pipeT_1, new_cell, (discrN, index, new_discr)) = ( State Outp { valid_mem = valid_mem'
211                                                                                                   , mem1      = mem1'
212                                                                                                   , mem2      = mem2'
213                                                                                                   , lutm      = lutm'
214                                                                                                   }
215                                                                                      , res_mem_out, output)
216   where
217     addrT                         = discr pipeT
218     addrT_1                       = discr pipeT_1
219     -- Purge completely reduced results from the system
220     clean_mem   | new_discr       = replace valid_mem discrN False
221                 | otherwise       = valid_mem
222     -- If a partial is fed  back to the pipeline, make its location invalid   
223     valid_mem'  | new_cell        = replace clean_mem addrT True
224                 | otherwise       = replace clean_mem addrT False
225     -- Two parrallel memories with the same write addr, but diff rdaddr for partial res and other for complete res
226     (mem1', partial)              = blockRAM mem1 (value pipeT) addrT   addrT new_cell
227     (mem2', complete)             = blockRAM mem2 (value pipeT) discrN  addrT new_cell
228     -- Lut maps discriminators to array index
229     (lutm', lut_out)              = blockRAM lutm index discrN discrN new_discr
230     res_mem_out                   = (valid_mem!addrT_1, (partial,addrT))
231     -- Output value to the system once a discriminator is reused
232     output                        = ((complete,lut_out), new_discr && (valid_mem!discrN))
233
234 -- ================================================================
235 -- = Controller guides correct inputs to the floating point adder =
236 -- ================================================================
237 controller :: 
238   (Cell, Cell, Cell, Cell) -> 
239   (Cell, Cell, Shift, Bool)
240 controller (inp1, inp2, pipeT, from_res_mem) = (arg1, arg2, shift, to_res_mem)
241   where
242     (arg1, arg2, shift, to_res_mem)
243       | valid pipeT && valid from_res_mem                          = (pipeT   , from_res_mem            , 0, False)
244       | valid pipeT && valid inp1 && discr pipeT == discr inp1     = (pipeT   , inp1                    , 1, False)
245       | valid inp1  && valid inp2 && discr inp1 == discr inp2      = (inp1    , inp2                    , 2, valid pipeT)
246       | valid inp1                                                 = (inp1    , (True, (0, discr inp1)) , 1, valid pipeT)
247       | otherwise                                                  = (notValid, notValid                , 0, valid pipeT)
248
249 -- =============================================
250 -- = Reducer: Wrap up all the above components =
251 -- =============================================
252 {-# ANN reducer TopEntity #-}
253 {-# ANN reducer (InitState 'initReducerState) #-}   
254 reducer :: 
255   ReducerState -> 
256   (DataInt, ArrayIndex) -> 
257   (ReducerState, OutputSignal)
258 reducer (State (Reducer {..})) (data_in, index) = ( State Reducer { discrState  = discrState'
259                                                                   , rippleState = rippleState'
260                                                                   , blockState  = blockState'
261                                                                   , pipeState   = pipeState'
262                                                                   , resultState = resultState'
263                                                                   , pipeline    = pipeline'
264                                                                   }
265                                                   , last outPipe)
266   where
267     -- Discriminator
268     (discrState' , discrN, new_discr)       = discriminator discrState (snd (last inPipe))
269     -- InputBuffer
270     (rippleState' , (inp1V, inp1I), (inp2V, inp2I)) = rippleBuffer rippleState (discrN, shift)
271     (blockState', inp1D, inp2D)             = blockBuffer blockState ((fst (last inPipe)), shift)
272     (inp1,inp2)                             = ((inp1V,(inp1D,inp1I)),(inp2V,(inp2D,inp2I))) 
273     -- FP Adder    
274     (pipeState'  , (pipeT_1, pipeT))        = fpAdder pipeState (arg1, arg2)
275     -- Result Buffer
276     (resultState', from_res_mem, output')   = resBuff resultState (pipeT, pipeT_1, to_res_mem, last discrO)
277     -- Controller
278     (arg1,arg2,shift,to_res_mem)            = controller (inp1, inp2, pipeT, (valT, snd from_res_mem))
279     -- Optimizations/Pipelining
280     valT_1  | discr pipeT == discr pipeT_1  = not (valid from_res_mem)
281             | otherwise                     = valid from_res_mem
282     (discrO, valT, inPipe , outPipe)        = pipeline    
283     pipeline'                               = ( (discrN, index, new_discr) >> discrO
284                                               , valT_1
285                                               , (data_in, index) >> inPipe
286                                               , output' >> outPipe
287                                               )
288
289
290 -- ========================
291 -- = Initial State values =
292 -- ========================
293 initDiscrState :: DiscrRecord
294 initDiscrState = DiscrR { prev_index = 255
295                         , cur_discr  = 127
296                         }
297                                          
298 initRippleState :: Vector (AdderDepthPL :+: D1) (CellType, Discr)
299 initRippleState = copy (False, 0)
300
301 initBlockState :: (Unsigned D4, Unsigned D4, Unsigned D4)
302 initBlockState = (0,1,0)
303                      
304 initPipeState :: 
305   ((DataInt,DataInt,DataInt,DataInt)
306   , Vector AdderDepthPL (CellType, Discr)
307   )
308 initPipeState = ((0,0,0,0),copy (False, 0))     
309
310 initAdderState :: Vector AdderDepth DataInt
311 initAdderState = copy 0
312
313 initResultState :: RAM DiscrRange CellType     
314 initResultState = copy False
315
316 initReducerState :: 
317   ( Vector AdderDepth (Discr, ArrayIndex, Bool)
318   , CellType
319   , Vector D2 (DataInt, ArrayIndex)
320   , Vector D2 OutputSignal
321   )
322 initReducerState = (copy (0, 0, False), False, copy (0,0), copy ((0,0), False))
323
324 initstate :: ReducerState
325 initstate = State ( Reducer { discrState  = State initDiscrState
326                             , rippleState = State initRippleState
327                             , blockState  = State Block { ptrs = initBlockState
328                                                         , buf1 = State (copy 0)
329                                                         , buf2 = State (copy 0)
330                                                         }
331                             , pipeState   = State (initPipeState, State initAdderState)
332                             , resultState = State Outp { valid_mem = initResultState 
333                                                        , mem1      = State (copy 0)
334                                                        , mem2      = State (copy 0)
335                                                        , lutm      = State (copy 0)
336                                                        }
337                             , pipeline    = initReducerState
338                             })
339
340 -- ==================
341 -- = Test Functions =
342 -- ==================          
343 run func state [] = []
344 run func state (i:input) = o:out
345   where
346     (state', o) = func state i
347     out         = run func state' input
348
349 runReducer =  ( reduceroutput
350               , validoutput
351               , equal
352               , allEqual
353               )
354   where
355     -- input = randominput 900 7
356     input  = siminput
357     istate = initstate
358     output = run reducer istate input
359     reduceroutput = P.map fst (filter (\x -> (snd x)) output)
360     validoutput   = [P.foldl (+) 0 
361                       (P.map (\z -> toInteger (fst z)) 
362                         (filter (\x -> (snd x) == i) input)) | i <- [0..30]]
363     equal = [validoutput!!i == toInteger (fst (reduceroutput!!i)) | 
364               i <- [0..30]]
365     allEqual = foldl1 (&&) equal
366
367 siminput :: [(DataInt, ArrayIndex)]
368 siminput =  [(1,0),(5,1),(12,1),(4,2),(9,2),(2,2),(13,2),(2,2),(6,2),(1,2),(12,2),(13,3),(6,3),(11,3),(2,3),(11,3),(5,4),(11,4),(1,4),(7,4),(3,4),(4,4),(5,5),(8,5),(8,5),(13,5),(10,5),(7,5),(9,6),(9,6),(3,6),(11,6),(14,6),(13,6),(10,6),(4,7),(15,7),(13,7),(10,7),(10,7),(6,7),(15,7),(9,7),(1,7),(7,7),(15,7),(3,7),(13,7),(7,8),(3,9),(13,9),(2,10),(9,11),(10,11),(9,11),(2,11),(14,12),(14,12),(12,13),(7,13),(9,13),(7,14),(14,15),(5,16),(6,16),(14,16),(11,16),(5,16),(5,16),(7,17),(1,17),(13,17),(10,18),(15,18),(12,18),(14,19),(13,19),(2,19),(3,19),(14,19),(9,19),(11,19),(2,19),(2,20),(3,20),(13,20),(3,20),(1,20),(9,20),(10,20),(4,20),(8,21),(4,21),(8,21),(4,21),(13,21),(3,21),(7,21),(12,21),(7,21),(13,21),(3,21),(1,22),(13,23),(9,24),(14,24),(4,24),(13,25),(6,26),(12,26),(4,26),(15,26),(3,27),(6,27),(5,27),(6,27),(12,28),(2,28),(8,28),(5,29),(4,29),(1,29),(2,29),(9,29),(10,29),(4,30),(6,30),(14,30),(11,30),(15,31),(15,31),(2,31),(14,31),(9,32),(3,32),(4,32),(6,33),(15,33),(1,33),(15,33),(4,33),(3,33),(8,34),(12,34),(14,34),(15,34),(4,35),(4,35),(12,35),(14,35),(3,36),(14,37),(3,37),(1,38),(15,39),(13,39),(13,39),(1,39),(5,40),(10,40),(14,40),(1,41),(6,42),(8,42),(11,42),(11,43),(2,43),(11,43),(8,43),(12,43),(15,44),(14,44),(6,44),(8,44),(9,45),(5,45),(12,46),(6,46),(5,46),(4,46),(2,46),(9,47),(7,48),(1,48),(3,48),(10,48),(1,48),(6,48),(6,48),(11,48),(11,48),(8,48),(14,48),(5,48),(11,49),(1,49),(3,49),(11,49),(8,49),(3,50),(8,51),(9,52),(7,52),(7,53),(8,53),(10,53),(11,53),(14,54),(11,54),(4,54),(6,55),(11,55),(5,56),(7,56),(6,56),(2,56),(4,56),(12,56),(4,57),(12,57),(2,57),(14,57),(9,57),(12,57),(5,57),(11,57),(7,58),(14,58),(2,58),(10,58),(2,58),(14,58),(7,58),(12,58),(1,58),(11,59),(8,59),(2,59),(14,59),(6,59),(6,59),(6,59),(14,59),(4,59),(1,59),(4,60),(14,60),(6,60),(4,60),(8,60),(12,60),(1,60),(8,60),(8,60),(13,60),(10,61),(11,61),(6,61),(14,61),(10,61),(3,62),(10,62),(7,62),(14,62),(10,62),(4,62),(6,62),(1,62),(3,63),(3,63),(1,63),(1,63),(15,63),(7,64),(1,65),(4,65),(11,66),(3,66),(13,66),(2,67),(2,67),(5,68),(15,68),(11,68),(8,68),(4,69),(11,69),(12,69),(8,69),(7,70),(9,70),(6,70),(9,70),(11,70),(14,70),(5,71),(7,71),(11,72),(5,72),(3,72),(2,72),(1,73),(13,73),(9,73),(14,73),(5,73),(6,73),(14,73),(13,73),(3,74),(13,74),(3,75),(14,75),(10,75),(5,75),(3,75),(8,75),(9,76),(7,76),(10,76),(10,76),(8,77),(10,77),(11,77),(8,77),(2,77),(9,77),(9,77),(12,77),(4,77),(14,77),(10,77),(7,77),(3,77),(10,78),(8,79),(14,79),(11,80),(15,81),(6,81),(4,82),(6,82),(1,82),(12,83),(6,83),(11,83),(12,83),(15,83),(13,83),(1,84),(2,84),(11,84),(5,84),(2,84),(2,84),(3,84),(4,85),(6,86),(5,86),(15,86),(8,86),(9,86),(9,87),(9,87),(12,87),(4,87),(13,88),(14,88),(10,88),(11,88),(7,88),(4,88),(9,88),(1,88),(4,88),(4,88),(12,88),(8,89),(3,89),(10,89),(10,89),(5,89),(14,89),(11,89),(10,89),(5,90),(6,90),(10,90),(9,90),(8,90),(10,90),(5,90),(11,90),(6,90),(10,90),(7,90),(3,91),(7,91),(5,91),(15,91),(4,91),(6,91),(8,91),(1,91),(8,91),(12,92),(8,93),(9,93),(12,94),(8,94),(5,94),(11,95),(13,95),(5,96),(12,96),(8,96),(4,96),(7,97),(6,97),(4,97),(1,98),(5,98),(12,98),(13,99),(7,100),(12,100),(4,100),(10,100),(2,101),(3,101),(14,101),(12,101),(5,101),(2,101),(14,101),(15,101),(7,102),(13,102),(5,102),(7,102),(4,102),(8,102),(12,103),(15,103),(2,103),(2,103),(6,103),(6,103),(1,104),(14,104),(15,105),(3,105),(13,105),(1,105),(8,105),(8,105),(15,105),(13,105),(13,105),(6,105),(9,105),(6,106),(14,107),(12,107),(7,108),(7,108),(6,109),(11,109),(14,110),(8,111),(5,111),(15,111),(14,111),(3,111),(13,112),(12,112),(5,112),(10,112),(7,112),(5,113),(3,113),(2,113),(1,113),(15,113),(8,113),(10,113),(3,114),(6,114),(15,114),(4,115),(8,115),(1,115),(12,115),(5,115),(6,116),(2,116),(13,116),(12,116),(6,116),(10,117),(8,117),(14,118),(10,118),(3,118),(15,119),(6,119),(6,120),(5,121),(8,121),(4,122),(1,122),(9,123),(12,123),(6,124),(10,124),(2,124),(11,124),(9,125),(8,126),(10,126),(11,126),(14,126),(2,126),(5,126),(7,126),(3,127),(12,127),(15,128),(4,128),(1,129),(14,129),(8,129),(9,129),(6,129),(1,130),(11,130),(2,130),(13,130),(14,131),(2,131),(15,131),(4,131),(15,131),(8,131),(3,131),(8,132),(1,132),(13,132),(8,132),(5,132),(11,132),(14,132),(14,132),(4,132),(14,132),(5,132),(11,133),(1,133),(15,133),(8,133),(12,133),(8,134),(14,135),(11,136),(9,137),(3,137),(15,138),(1,138),(1,139),(4,139),(3,140),(10,140),(8,141),(12,141),(4,141),(12,141),(13,141),(10,141),(4,142),(6,142),(15,142),(4,142),(2,143),(14,143),(5,143),(10,143),(8,143),(9,143),(3,143),(11,143),(6,144),(3,145),(9,145),(10,145),(6,145),(11,145),(4,145),(13,145),(5,145),(4,145),(1,145),(3,145),(15,145),(14,146),(11,146),(9,146),(9,146),(10,146),(9,146),(3,146),(2,146),(10,146),(6,146),(7,146),(3,147),(4,147),(15,147),(11,147),(15,147),(1,147),(15,147),(14,147),(15,147),(5,147),(15,147),(4,147),(2,148),(12,149),(12,150),(10,150),(1,150),(7,151),(4,151),(14,151),(15,151),(5,152),(11,153),(3,153),(1,153),(1,153),(12,153),(1,154),(1,155),(11,155),(8,155),(3,155),(8,155),(8,155),(2,155),(9,156),(6,156),(12,156),(1,156),(3,156),(8,156),(5,157),(9,157),(12,157),(6,157),(8,158),(15,159),(2,159),(10,160),(10,160),(2,160),(6,160),(10,160),(8,160),(13,160),(12,161),(15,161),(14,161),(10,161),(13,161),(14,161),(3,161),(2,161),(1,161),(11,161),(7,161),(8,161),(4,162),(9,163),(3,164),(5,164),(9,164),(9,165),(7,165),(1,165),(6,166),(14,166),(3,166),(14,166),(4,166),(14,167),(5,167),(13,167),(12,167),(13,168),(9,168)]