Inline all simple top level functions, not just compiler-generated ones.
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies, ScopedTypeVariables, RecordWildCards #-}
2 module Main where
3
4 import System.Random
5 import System.IO.Unsafe (unsafePerformIO,unsafeInterleaveIO)
6 import qualified Prelude as P
7 import CLasH.HardwareTypes
8 import CLasH.Translator.Annotations
9
10 -- =======================================
11 -- = System size configuration variables =
12 -- =======================================
13
-- | Bit width of the data words flowing through the reducer ('DataInt').
type DataSize   = D8
-- | Bit width of the array index attached to every input ('ArrayIndex').
type IndexSize  = D8
-- | Bit width of the discriminator counter kept by 'discriminator'.
type DiscrSize  = D7
-- | Upper bound of 'Discr' (a RangedWord).
type DiscrRange = D127
-- | Number of stages in the simulated adder pipeline; also the bound of the
-- input buffer's read and write pointers.
type AdderDepth = D14
19
-- ================
-- = Type Aliases =
-- ================
23
-- | How many buffered inputs the controller consumes in one cycle (0, 1 or 2).
type Shift      = RangedWord D2
-- | A data word.
type DataInt    = SizedWord DataSize
-- | The array index that tags every input value.
type ArrayIndex = SizedWord IndexSize
-- | A discriminator: the tag identifying one reduction in flight.
type Discr      = RangedWord DiscrRange
28
-- | What the reducer emits each cycle: a (sum, array index) pair plus a flag
-- telling whether that pair is a completed result.
type OutputSignal = ((DataInt, ArrayIndex), Bool)
34
-- | Validity tag carried by every 'Cell'.
data CellType
  = Valid     -- ^ the cell carries meaningful data
  | NotValid  -- ^ the cell is empty
  deriving Eq
37                       
-- | A possibly-empty slot in the reducer: a validity tag, a data word, and the
-- discriminator of the reduction the word belongs to.
type Cell = (CellType, (DataInt, Discr))
43
-- | The canonical empty cell; both its data word and its discriminator are 0.
notValid :: Cell
notValid = (NotValid, (0, 0))
46
47 -- ================================
48 -- = Cell type accessor functions =
49 -- ================================
50
-- | Does the cell carry meaningful data?
valid :: Cell -> Bool
valid (tag, _) = case tag of
  Valid    -> True
  NotValid -> False
54
-- | The data word stored in a cell, whether or not the cell is valid.
value :: Cell -> DataInt
value cell = case cell of
  (_, (v, _)) -> v
57
-- | The discriminator stored in a cell, whether or not the cell is valid.
discr :: Cell -> Discr
discr cell = case cell of
  (_, (_, d)) -> d
60
61 -- =======================
62 -- = Reducer State types =
63 -- =======================
64
-- | State of 'discriminator'.
data DiscrRecord = DiscrR
  { prev_index :: ArrayIndex           -- ^ array index seen on the previous cycle
  , cur_discr  :: SizedWord DiscrSize  -- ^ discriminator counter handed out last
  }

-- | The discriminator's state wrapped in 'State'.
type DiscrState = State DiscrRecord
70
-- | State of the input circular buffer ('circBuffer').
data CircRecord = Circ
  { mem   :: Vector (AdderDepth :+: D1) (DataInt, Discr) -- ^ buffer slots
  , rdptr :: RangedWord AdderDepth          -- ^ slot offered to the controller first
  , wrptr :: RangedWord AdderDepth          -- ^ slot the next input is written to
  , count :: RangedWord (AdderDepth :+: D1) -- ^ number of entries currently buffered
  }

-- | The input buffer's state wrapped in 'State'.
type CircState = State CircRecord
78
-- | Contents of the adder pipeline, one cell per stage, wrapped in 'State'.
type FpAdderState = State (Vector AdderDepth Cell)
80
-- | State of 'resBuff', the register-array result buffer.
data OutputRecord = Outp
  { res_mem :: RAM DiscrRange Cell             -- ^ partial result per discriminator
  , lut     :: MemState DiscrRange ArrayIndex  -- ^ discriminator to array index table
  }

-- | 'resBuff' state wrapped in 'State'.
type OutputState = State OutputRecord
86
-- | State of 'resBuffO', the block-RAM based result buffer.
data OutputRecordO = OutpO
  { valid_mem :: RAM DiscrRange CellType        -- ^ which discriminators hold a partial
  , mem1      :: MemState DiscrRange DataInt    -- ^ partial values, read for feedback
  , mem2      :: MemState DiscrRange DataInt    -- ^ same values, read for completed results
  , lutm      :: MemState DiscrRange ArrayIndex -- ^ discriminator to array index table
  }

-- | 'resBuffO' state wrapped in 'State'.
type OutputStateO = State OutputRecordO
94
-- | Combined state of every stateful component inside 'reducer'.
data ReducerRecord = Reducer
  { discrState  :: DiscrState    -- ^ state of 'discriminator'
  , inputState  :: CircState     -- ^ state of 'circBuffer'
  , pipeState   :: FpAdderState  -- ^ state of 'fpAdder'
  , resultState :: OutputStateO  -- ^ state of 'resBuffO'
  }

-- | The reducer's state wrapped in 'State'.
type ReducerState = State ReducerRecord
102
-- =============================================================
-- = Discriminator: Hands out new discriminators to the system =
-- =============================================================
{-# ANN discriminator (InitState 'initDiscrState) #-}
-- | Tags every incoming (value, array index) pair with a discriminator.
--
-- Whenever the array index differs from the one seen on the previous cycle,
-- the discriminator counter is incremented (a new reduction starts);
-- otherwise the current discriminator is reused. Returns the updated state,
-- the value paired with its discriminator, and whether a new discriminator
-- was allocated this cycle.
discriminator ::  DiscrState -> (DataInt, ArrayIndex) -> (DiscrState, (DataInt, Discr), Bool)
discriminator (State (DiscrR {..})) (data_in, index) = (State nextRecord, (data_in, tag), isNew)
  where
    -- A change of array index marks the start of a new reduction.
    isNew      = index /= prev_index
    nextCount  = if isNew then cur_discr + 1 else cur_discr
    nextRecord = DiscrR { prev_index = index, cur_discr = nextCount }
    tag        = fromSizedWord nextCount
119
-- =====================================================
-- = Input Buffer: Buffers incoming inputs when needed =
-- =====================================================
{-# ANN circBuffer (InitState 'initCircState) #-}
-- | Circular buffer between the discriminator and the adder.
--
-- Every cycle the incoming (value, discriminator) pair is written at 'wrptr'.
-- The entries at 'rdptr' and one slot below it are offered to the controller
-- as cells. They are marked 'notValid' when fewer than one (respectively two)
-- entries are buffered. The controller's 'shift' (0, 1 or 2) says how many of
-- them it consumed; this moves 'rdptr' and adjusts 'count'. Both pointers move
-- downwards and wrap from 0 to the top slot.
--
-- NOTE(review): nothing guards 'count + 1' when the buffer is already full and
-- nothing is consumed; presumably the controller prevents that — confirm.
circBuffer :: CircState ->
              ((DataInt, Discr), Shift) ->
              (CircState, Cell, Cell)
circBuffer (State (Circ {..})) (inp, shift) = (State nextRec, headCell, nextCell)
  where
    nextRec = Circ { mem   = replace mem wrptr inp
                   , rdptr = rdptr'
                   , wrptr = down wrptr
                   , count = count'
                   }
    -- Highest pointer value; stepping down from 0 wraps around to it.
    top :: RangedWord AdderDepth
    top = fromInteger (fromIntegerT (undefined :: AdderDepth))
    -- Move a pointer one slot down, wrapping from 0 to 'top'.
    down p = if p == 0 then top else p - 1
    -- One input always enters, so consuming 0/1/2 entries changes 'count' by +1/0/-1.
    (rdptr', count') = case shift of
      0 -> (rdptr, count + 1)
      1 -> (down rdptr, count)
      _ -> (down (down rdptr), count - 1)
    headCell = if count == 0 then notValid else (Valid, mem ! rdptr)
    nextCell = if count <= 1 then notValid else (Valid, mem ! down rdptr)
150     
151 -- ============================================
152 -- = Simulated pipelined floating point adder =
153 -- ============================================
{-# ANN fpAdder (InitState 'initPipeState) #-}
-- | Stand-in for a pipelined floating point adder with 'AdderDepth' stages.
--
-- Each cycle the sum of the two argument cells enters the head of the
-- pipeline, and the cell at the tail leaves it. A sum therefore comes out
-- 'AdderDepth' cycles after its arguments went in. The sum carries 'arg1''s
-- discriminator.
--
-- Only 'arg1''s validity is checked; the controller never pairs a valid first
-- argument with an invalid second one. An invalid 'arg1' inserts 'notValid'.
fpAdder ::  FpAdderState -> (Cell, Cell) -> (FpAdderState, Cell)
fpAdder (State stages) (arg1, arg2) = (State (entering +> init stages), last stages)
  where
    entering = if valid arg1
                 then (Valid, (value arg1 + value arg2, discr arg1))
                 else notValid
163
164 -- ==============================================================
165 -- = Partial Results buffers, purges completely reduced results =
166 -- ==============================================================
-- | Partial result buffer kept in a register array ('res_mem'). 'reducer'
-- uses the block-RAM variant 'resBuffO' instead.
--
-- Each cycle:
--
-- * when a new discriminator is allocated, the slot for that discriminator
--   ('discrN') is purged;
-- * a valid pipeline output empties its discriminator's slot, and that
--   partial (if any) is returned for feeding back into the adder;
-- * a valid 'new_cell' is stored under its own discriminator.
--
-- The result previously held under 'discrN' is emitted together with
-- 'lut_out', the 'lut' entry read for 'discrN'. Presumably that is the array
-- index recorded when 'discrN' was last allocated. The result is flagged
-- valid only when 'discrN' was just allocated and the result is valid.
resBuff ::  OutputState -> ( Cell, Cell, ArrayIndex, (Discr, Bool)) -> (OutputState, Cell, OutputSignal)
resBuff (State (Outp {..})) (pipe_out, new_cell, index, (discrN, new_discr)) =
    (State (Outp { res_mem = stored, lut = lut' }), fed_back, output)
  where
    -- Purge completely reduced results from the system.
    purged   = if new_discr then replace res_mem discrN notValid else res_mem
    -- A partial fed back to the pipeline no longer lives in memory.
    drained  = if valid pipe_out then replace purged (discr pipe_out) notValid else purged
    -- Store a new partial if there is one.
    stored   = if valid new_cell then replace drained (discr new_cell) new_cell else drained
    -- Partial for the pipeline output's discriminator, read before any update.
    fed_back = if valid pipe_out then res_mem ! discr pipe_out else notValid
    -- The LUT maps discriminators to array indices.
    (lut', lut_out) = blockRAM lut index discrN discrN new_discr
    -- Result accumulated under the current discriminator, read before any update.
    finished = res_mem ! discrN
    output   = ((value finished, lut_out), new_discr && valid finished)
192
193 -- ===================================================
194 -- = Optimized Partial Result Buffer, uses BlockRAMs =
195 -- ===================================================
{-# ANN resBuffO (InitState 'initResultState) #-}
-- | Partial result buffer built from block RAMs; this is the variant
-- 'reducer' uses.
--
-- State:
--
-- * 'valid_mem' records which discriminators currently hold a partial.
-- * 'mem1' and 'mem2' receive identical writes at the pipeline output's
--   discriminator. 'mem1' is read at that same discriminator (the partial to
--   feed back). 'mem2' is read at the current discriminator 'discrN' (the
--   completed result).
-- * 'lutm' is written with the array index whenever a new discriminator is
--   allocated.
--
-- Returns the new state; the partial stored for the pipeline output's
-- discriminator (with its validity flag and that discriminator); and the
-- completed result for 'discrN' with its looked-up index. The completed
-- result is flagged valid only when 'discrN' was just (re)allocated and still
-- held a result.
resBuffO ::  OutputStateO -> ( Cell, Cell, ArrayIndex, (Discr, Bool)) -> (OutputStateO, Cell, OutputSignal)
resBuffO (State (OutpO {..})) (pipe_out, new_cell, index, (discrN, new_discr)) =
    ( State (OutpO { valid_mem = valid_mem'
                   , mem1      = mem1'
                   , mem2      = mem2'
                   , lutm      = lutm'
                   })
    , res_mem_out
    , output
    )
  where
    -- All partial-result traffic uses the pipeline output's discriminator.
    -- 'new_cell' is also written there. The controller only ever passes
    -- 'notValid' or the pipeline output itself as 'new_cell'.
    addr                          = discr pipe_out
    -- Purge completely reduced results from the system.
    clean_mem   | new_discr       = replace valid_mem discrN NotValid
                | otherwise       = valid_mem
    -- Mark a newly stored partial valid. Invalidate the location of a partial
    -- that is fed back to the pipeline. When the pipeline output is not valid,
    -- nothing is fed back or stored, so the flags are left alone. The
    -- pipeline's invalid cells are 'notValid', whose discriminator is 0, so
    -- writing NotValid at 'addr' would wipe a partial stored for
    -- discriminator 0. This matches 'resBuff', which only invalidates for a
    -- valid pipeline output.
    valid_mem'  | valid new_cell  = replace clean_mem addr Valid
                | valid pipe_out  = replace clean_mem addr NotValid
                | otherwise       = clean_mem
    -- Two parallel memories with the same write address but different read
    -- addresses: one for the partial result, one for the completed result.
    (mem1', partial)              = blockRAM mem1 (value new_cell) addr   addr (valid new_cell)
    (mem2', complete)             = blockRAM mem2 (value new_cell) discrN addr (valid new_cell)
    -- The LUT maps discriminators to array indices.
    (lutm', lut_out)              = blockRAM lutm index discrN discrN new_discr
    res_mem_out                   = (valid_mem!addr, (partial,addr))
    -- Output a value to the system once a discriminator is reused.
    output                        = ((complete,lut_out), new_discr && (valid_mem!discrN) == Valid)
220
221 -- ================================================================
222 -- = Controller guides correct inputs to the floating point adder =
223 -- ================================================================
-- | Decides what enters the adder pipeline this cycle.
--
-- Inputs are: the two buffered inputs offered by the input buffer; the adder
-- pipeline output; and the partial stored for the pipeline output's
-- discriminator.
--
-- Returns the two adder arguments, how many buffered inputs were consumed,
-- and the cell to store in the result buffer. Cases in priority order:
--
--   1. pipeline output + valid stored partial (nothing consumed);
--   2. pipeline output + first input with the same discriminator (1 consumed);
--   3. first two inputs with the same discriminator (2 consumed);
--   4. first input + zero carrying its discriminator (1 consumed);
--   5. nothing enters the pipeline.
--
-- In cases 3-5 the pipeline output is handed to the result buffer instead.
controller :: (Cell, Cell, Cell, Cell) -> (Cell, Cell, Shift, Cell)
controller (inp1, inp2, pipe_out, from_res_mem)
  | pipeReady && valid from_res_mem     = (pipe_out, from_res_mem,             0, notValid)
  | pipeReady && inpReady && pipeMatch  = (pipe_out, inp1,                     1, notValid)
  | inpReady && valid inp2 && inpsMatch = (inp1,     inp2,                     2, pipe_out)
  | inpReady                            = (inp1,     (Valid, (0, discr inp1)), 1, pipe_out)
  | otherwise                           = (notValid, notValid,                 0, pipe_out)
  where
    pipeReady = valid pipe_out
    inpReady  = valid inp1
    pipeMatch = discr pipe_out == discr inp1
    inpsMatch = discr inp1 == discr inp2
233
234 -- =============================================
235 -- = Reducer: Wrap up all the above components =
236 -- =============================================
{-# ANN reducer TopEntity #-}
-- | Top-level reducer: accepts one (value, array index) pair per cycle.
--
-- Data flow:
--
-- * 'discriminator' tags each input with a discriminator;
-- * 'circBuffer' queues the tagged inputs;
-- * 'controller' chooses what enters the adder pipeline ('fpAdder') and how
--   many queued inputs that consumes;
-- * 'resBuffO' keeps partial sums and emits a completed sum when its
--   discriminator is reused.
reducer ::  ReducerState -> (DataInt, ArrayIndex) -> (ReducerState, OutputSignal)
reducer (State (Reducer {..})) (data_in, index) = (State nextState, output)
  where
    nextState = Reducer { discrState  = discrNext
                        , inputState  = inputNext
                        , pipeState   = pipeNext
                        , resultState = resultNext
                        }
    -- The controller's decisions feed back into the buffer ('shift') and the
    -- result buffer ('to_res_mem') within the same evaluation step.
    (arg1, arg2, shift, to_res_mem)            = controller (inp1, inp2, pipe_out, from_res_mem)
    (discrNext, tagged@(_, discrN), new_discr) = discriminator discrState (data_in, index)
    (inputNext, inp1, inp2)                    = circBuffer inputState (tagged, shift)
    (pipeNext, pipe_out)                       = fpAdder pipeState (arg1, arg2)
    (resultNext, from_res_mem, output)         = resBuffO resultState (pipe_out, to_res_mem, index, (discrN, new_discr))
251
252 -- -------------------------------------------------------
253 -- -- Test Functions
254 -- -------------------------------------------------------            
255 --             
-- | Drive a step function over a list of inputs. The state is threaded from
-- one step to the next, and one output is collected per input.
--
-- The result is produced lazily, so this also works on infinite input lists.
run :: (s -> i -> (s, o)) -> s -> [i] -> [o]
run _    _     []             = []
run step state (inp : inputs) = out : run step state' inputs
  where
    (state', out) = step state inp
262
-- | Simulate the reducer on 'siminput', starting from 'initstate', and print
-- each output signal on its own line.
runReducerIO :: IO ()
runReducerIO = mapM_ P.print (run reducer initstate siminput)
270
-- | Simulate the reducer on 'siminput' and check its results against sums
-- computed directly from the input.
--
-- Returns a triple of:
--
-- * the (sum, array index) pairs the reducer flagged as valid, in output
--   order;
-- * the expected sum for each array index 0..30;
-- * for each of those indices, whether the reducer's corresponding valid
--   result matches. A missing result counts as a mismatch rather than
--   crashing.
--
-- NOTE(review): the comparison assumes the reducer emits completed results in
-- ascending index order; confirm before reusing with other inputs.
runReducer :: ([(DataInt, ArrayIndex)], [Integer], [Bool])
runReducer = (reduceroutput, validoutput, equal)
  where
    -- input = randominput 900 7
    input         = siminput
    output        = run reducer initstate input
    reduceroutput = P.map fst (filter snd output)
    -- Reference result: the plain sum of all values given for each index.
    validoutput   = [ P.sum [ toInteger v | (v, idx) <- input, idx == i ]
                    | i <- [0 .. 30] ]
    -- Pad the reducer's results with Nothing so that a short run yields
    -- False instead of an out-of-range (!!) error.
    equal         = P.zipWith (==)
                      (P.map P.Just validoutput)
                      (P.map (\r -> P.Just (toInteger (fst r))) reduceroutput P.++ P.repeat P.Nothing)
286  
-- | An infinite stream of pseudo-random numbers in the range [1, x].
--
-- NOTE(review): the generator comes from 'unsafePerformIO newStdGen'. That
-- expression does not depend on the argument, so an optimising compiler may
-- float it out and share one generator between calls. Treat this as a
-- test-only helper.
randX :: Integer -> [Integer]
randX upper = randomRs (1, upper) gen
  where
    gen = unsafePerformIO newStdGen
290
-- | An infinite list of array indices counting up from @i@. Each index is
-- repeated a random number of times between 1 and the current bound @m@.
-- The bound grows by one per index and restarts at 1 once it reaches 15.
--
-- NOTE(review): every step draws a fresh generator through
-- 'unsafePerformIO newStdGen'; only suitable for test input.
randindex m i
  | m == 15   = randindex 1 i
  | otherwise = P.replicate reps i P.++ randindex (m + 1) (i + 1)
  where
    -- Only the first value of the random stream is used.
    (reps : _) = randomRs (1, m) (unsafePerformIO newStdGen)
297
-- | @n@ random (value, array index) pairs to feed the reducer. Values are drawn
-- from 1..@x@ via 'randX'; indices come from 'randindex', starting at 0.
randominput n x = P.zip values indices
  where
    values  = [ (fromInteger v :: DataInt)    | v <- P.take n (randX x) ]
    indices = [ (fromInteger k :: ArrayIndex) | k <- P.take n (randindex 7 0) ]
-- | Entry point: simulate the reducer on the built-in test input and print
-- every output.
main :: IO ()
main = runReducerIO
305
-- | Power-on state of 'discriminator'.
--
-- NOTE(review): 'prev_index' starts at 255 and 'cur_discr' at 127. This
-- appears chosen so that the first input (any index other than 255) allocates
-- a new discriminator, and the counter's first increment wraps to 0. Confirm
-- the SizedWord wrap-around semantics.
initDiscrState :: DiscrRecord
initDiscrState = DiscrR { cur_discr  = 127
                        , prev_index = 255
                        }
310                         
-- | Power-on state of the input buffer: empty, with both pointers at the top
-- slot.
--
-- The top slot is derived from 'AdderDepth', exactly as 'circBuffer' computes
-- its wrap-around target. This keeps the initial pointers in range if the
-- pipeline depth changes, instead of relying on a hard-coded 14.
initCircState :: CircRecord
initCircState = Circ { mem   = copy (0, 0)
                     , rdptr = top
                     , wrptr = top
                     , count = 0
                     }
  where
    top = fromInteger (fromIntegerT (undefined :: AdderDepth))
317                      
-- | Power-on contents of the adder pipeline: every stage holds an empty cell.
initPipeState :: Vector AdderDepth Cell
initPipeState = copy notValid
320
-- | Power-on contents of 'valid_mem': no discriminator holds a result.
--
-- NOTE(review): 'resBuffO' names this as its InitState, but that component's
-- state is a whole 'OutputRecordO', which also contains the block RAMs
-- 'mem1', 'mem2' and 'lutm'. Confirm the translator accepts an initial value
-- for 'valid_mem' alone.
initResultState :: RAM DiscrRange CellType
initResultState = copy NotValid
323
-- | Power-on state of the complete reducer, as used by the simulation
-- helpers. The block RAM contents start zeroed.
initstate :: ReducerState
initstate = State Reducer
  { discrState  = State initDiscrState
  , inputState  = State initCircState
  , pipeState   = State initPipeState
  , resultState = State resultInit
  }
  where
    resultInit = OutpO { valid_mem = initResultState
                       , mem1      = State (copy (0 :: DataInt))
                       , mem2      = State (copy (0 :: DataInt))
                       , lutm      = State (copy (0 :: ArrayIndex))
                       }
334
{-# ANN siminput TestInput #-}
-- | Fixed test stimulus: (value, array index) pairs, fed to the reducer one
-- per cycle. The indices appear to be grouped and non-decreasing. That matches
-- 'discriminator', which starts a new reduction whenever the index changes.
siminput :: [(DataInt, ArrayIndex)]
siminput =
  [(1,0),(5,1),(12,1),(4,2),(9,2),(2,2),(13,2),(2,2),(6,2),(1,2),(12,2),
   (13,3),(6,3),(11,3),(2,3),(11,3),(5,4),(11,4),(1,4),(7,4),(3,4),(4,4),
   (5,5),(8,5),(8,5),(13,5),(10,5),(7,5),(9,6),(9,6),(3,6),(11,6),(14,6),(13,6),(10,6),
   (4,7),(15,7),(13,7),(10,7),(10,7),(6,7),(15,7),(9,7),(1,7),(7,7),(15,7),(3,7),(13,7),
   (7,8),(3,9),(13,9),(2,10),(9,11),(10,11),(9,11),(2,11),(14,12),(14,12),(12,13),(7,13),(9,13),
   (7,14),(14,15),(5,16),(6,16),(14,16),(11,16),(5,16),(5,16),(7,17),(1,17),(13,17),
   (10,18),(15,18),(12,18),(14,19),(13,19),(2,19),(3,19),(14,19),(9,19),(11,19),(2,19),
   (2,20),(3,20),(13,20),(3,20),(1,20),(9,20),(10,20),(4,20),
   (8,21),(4,21),(8,21),(4,21),(13,21),(3,21),(7,21),(12,21),(7,21),(13,21),(3,21),
   (1,22),(13,23),(9,24),(14,24),(4,24),(13,25),(6,26),(12,26),(4,26),(15,26),
   (3,27),(6,27),(5,27),(6,27),(12,28),(2,28),(8,28),(5,29),(4,29),(1,29),(2,29),(9,29),(10,29),
   (4,30),(6,30),(14,30),(11,30),(15,31),(15,31),(2,31),(14,31),(9,32),(3,32),(4,32),
   (6,33),(15,33),(1,33),(15,33),(4,33),(3,33),(8,34),(12,34),(14,34),(15,34),
   (4,35),(4,35),(12,35),(14,35),(3,36),(14,37),(3,37),(1,38),(15,39),(13,39),(13,39),(1,39),
   (5,40),(10,40),(14,40),(1,41),(6,42),(8,42),(11,42),(11,43),(2,43),(11,43),(8,43),(12,43),
   (15,44),(14,44),(6,44),(8,44),(9,45),(5,45),(12,46),(6,46),(5,46),(4,46),(2,46),(9,47),
   (7,48),(1,48),(3,48),(10,48),(1,48),(6,48),(6,48),(11,48),(11,48),(8,48),(14,48),(5,48),
   (11,49),(1,49),(3,49),(11,49),(8,49),(3,50),(8,51),(9,52),(7,52),(7,53),(8,53),(10,53),(11,53),
   (14,54),(11,54),(4,54),(6,55),(11,55),(5,56),(7,56),(6,56),(2,56),(4,56),(12,56),
   (4,57),(12,57),(2,57),(14,57),(9,57),(12,57),(5,57),(11,57),
   (7,58),(14,58),(2,58),(10,58),(2,58),(14,58),(7,58),(12,58),(1,58),
   (11,59),(8,59),(2,59),(14,59),(6,59),(6,59),(6,59),(14,59),(4,59),(1,59),
   (4,60),(14,60),(6,60),(4,60),(8,60),(12,60),(1,60),(8,60),(8,60),(13,60),
   (10,61),(11,61),(6,61),(14,61),(10,61),(3,62),(10,62),(7,62),(14,62),(10,62),(4,62),(6,62),(1,62),
   (3,63),(3,63),(1,63),(1,63),(15,63),(7,64),(1,65),(4,65),(11,66),(3,66),(13,66),(2,67),(2,67),
   (5,68),(15,68),(11,68),(8,68),(4,69),(11,69),(12,69),(8,69),
   (7,70),(9,70),(6,70),(9,70),(11,70),(14,70),(5,71),(7,71),(11,72),(5,72),(3,72),(2,72),
   (1,73),(13,73),(9,73),(14,73),(5,73),(6,73),(14,73),(13,73),(3,74),(13,74),
   (3,75),(14,75),(10,75),(5,75),(3,75),(8,75),(9,76),(7,76),(10,76),(10,76),
   (8,77),(10,77),(11,77),(8,77),(2,77),(9,77),(9,77),(12,77),(4,77),(14,77),(10,77),(7,77),(3,77),
   (10,78),(8,79),(14,79),(11,80),(15,81),(6,81),(4,82),(6,82),(1,82),
   (12,83),(6,83),(11,83),(12,83),(15,83),(13,83),(1,84),(2,84),(11,84),(5,84),(2,84),(2,84),(3,84),
   (4,85),(6,86),(5,86),(15,86),(8,86),(9,86),(9,87),(9,87),(12,87),(4,87),
   (13,88),(14,88),(10,88),(11,88),(7,88),(4,88),(9,88),(1,88),(4,88),(4,88),(12,88),
   (8,89),(3,89),(10,89),(10,89),(5,89),(14,89),(11,89),(10,89),
   (5,90),(6,90),(10,90),(9,90),(8,90),(10,90),(5,90),(11,90),(6,90),(10,90),(7,90),
   (3,91),(7,91),(5,91),(15,91),(4,91),(6,91),(8,91),(1,91),(8,91),(12,92),(8,93),(9,93),
   (12,94),(8,94),(5,94),(11,95),(13,95),(5,96),(12,96),(8,96),(4,96),(7,97),(6,97),(4,97),
   (1,98),(5,98),(12,98),(13,99),(7,100),(12,100),(4,100),(10,100),
   (2,101),(3,101),(14,101),(12,101),(5,101),(2,101),(14,101),(15,101),
   (7,102),(13,102),(5,102),(7,102),(4,102),(8,102),(12,103),(15,103),(2,103),(2,103),(6,103),(6,103),
   (1,104),(14,104),(15,105),(3,105),(13,105),(1,105),(8,105),(8,105),(15,105),(13,105),(13,105),(6,105),(9,105),
   (6,106),(14,107),(12,107),(7,108),(7,108),(6,109),(11,109),(14,110),
   (8,111),(5,111),(15,111),(14,111),(3,111),(13,112),(12,112),(5,112),(10,112),(7,112),
   (5,113),(3,113),(2,113),(1,113),(15,113),(8,113),(10,113),(3,114),(6,114),(15,114),
   (4,115),(8,115),(1,115),(12,115),(5,115),(6,116),(2,116),(13,116),(12,116),(6,116),
   (10,117),(8,117),(14,118),(10,118),(3,118),(15,119),(6,119),(6,120),(5,121),(8,121),
   (4,122),(1,122),(9,123),(12,123),(6,124),(10,124),(2,124),(11,124),(9,125),
   (8,126),(10,126),(11,126),(14,126),(2,126),(5,126),(7,126),(3,127),(12,127),
   (15,128),(4,128),(1,129),(14,129),(8,129),(9,129),(6,129),(1,130),(11,130),(2,130),(13,130),
   (14,131),(2,131),(15,131),(4,131),(15,131),(8,131),(3,131),
   (8,132),(1,132),(13,132),(8,132),(5,132),(11,132),(14,132),(14,132),(4,132),(14,132),(5,132),
   (11,133),(1,133),(15,133),(8,133),(12,133),(8,134),(14,135),(11,136),(9,137),(3,137),
   (15,138),(1,138),(1,139),(4,139),(3,140),(10,140),
   (8,141),(12,141),(4,141),(12,141),(13,141),(10,141),(4,142),(6,142),(15,142),(4,142),
   (2,143),(14,143),(5,143),(10,143),(8,143),(9,143),(3,143),(11,143),(6,144),
   (3,145),(9,145),(10,145),(6,145),(11,145),(4,145),(13,145),(5,145),(4,145),(1,145),(3,145),(15,145),
   (14,146),(11,146),(9,146),(9,146),(10,146),(9,146),(3,146),(2,146),(10,146),(6,146),(7,146),
   (3,147),(4,147),(15,147),(11,147),(15,147),(1,147),(15,147),(14,147),(15,147),(5,147),(15,147),(4,147),
   (2,148),(12,149),(12,150),(10,150),(1,150),(7,151),(4,151),(14,151),(15,151),(5,152),
   (11,153),(3,153),(1,153),(1,153),(12,153),(1,154),
   (1,155),(11,155),(8,155),(3,155),(8,155),(8,155),(2,155),
   (9,156),(6,156),(12,156),(1,156),(3,156),(8,156),(5,157),(9,157),(12,157),(6,157),(8,158),
   (15,159),(2,159),(10,160),(10,160),(2,160),(6,160),(10,160),(8,160),(13,160),
   (12,161),(15,161),(14,161),(10,161),(13,161),(14,161),(3,161),(2,161),(1,161),(11,161),(7,161),(8,161),
   (4,162),(9,163),(3,164),(5,164),(9,164),(9,165),(7,165),(1,165),
   (6,166),(14,166),(3,166),(14,166),(4,166),(14,167),(5,167),(13,167),(12,167),(13,168),(9,168)]