Add builtin blockRAM primitive
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies #-}
2 module Reducer where
3
4 import System.Random
5 import System.IO.Unsafe (unsafePerformIO,unsafeInterleaveIO)
6
7 import qualified Prelude as P
8 import CLasH.HardwareTypes
9 import CLasH.Translator.Annotations
10
11 type DataSize       = D8
12 type IndexSize      = D8
13 type DiscrSize      = D3
14 type DiscrRange     = D7
15 type AdderDepth     = D2
16
17 type DataInt        = SizedWord DataSize
18 type ArrayIndex     = SizedWord IndexSize
19 type Discr          = RangedWord DiscrRange
20
21 type ReducerState   = State ( DiscrState
22                       , InputState
23                       , FpAdderState
24                       , OutputState
25                       )
26 type ReducerSignal  = ( ( DataInt
27                         , Discr
28                         )
29                       , Bit
30                       )
31                       
32 type OutputSignal   = ( (DataInt
33                         , ArrayIndex
34                         )
35                       , Bit
36                       )
37
38 type DiscrState     = State ( ArrayIndex
39                       , SizedWord DiscrSize
40                       )
41                      
42 type InputState     = State ( Vector (AdderDepth :+: D1) ReducerSignal
43                       , RangedWord AdderDepth
44                       )
45
46 type FpAdderState   = State (Vector AdderDepth ReducerSignal)
47
48 type OutputState    = State ( MemState DiscrRange DataInt
49                             , MemState DiscrRange DataInt
50                             , RAM DiscrRange ArrayIndex
51                             , RAM DiscrRange Bit
52                       )
53 {-
54 Discriminator adds a discriminator to each input value
55
56 State:
57 prev_index: previous index
58 cur_discr: current discriminator
59
60 Input:
61 data_in: input value
62 index: row index
63
64 Output:
65 data_in: output value
66 discr: discriminator belonging to output value
67 new_discr: value of new discriminator, is -1 if cur_discr hasn't changed
68 index: Index belonging to the new discriminator 
69 -}
70 discriminator ::  DiscrState -> (DataInt, ArrayIndex) -> 
71                   ( DiscrState
72                   , ((DataInt, Discr), (Bit, ArrayIndex))
73                   )
74 discriminator (State (prev_index,cur_discr)) (data_in, index) = 
75   (State (prev_index', cur_discr'), ((data_in, discr),(new_discr, index)))
76   where
77     -- Update discriminator if index changes
78     cur_discr'  | prev_index == index = cur_discr
79                 | otherwise           = cur_discr + 1
80     -- Notify OutputBuffer if a new discriminator becomes in use
81     new_discr   | prev_index == index = Low
82                 | otherwise           = High
83     prev_index'                       = index
84     discr                             = fromSizedWord cur_discr'
85
86 {-
87 Second attempt at Fifo
88 Uses "write pointer"... ugly...
89 Can potentially be mapped to hardware
90
91 State:
92 mem: content of the FIFO
93 wrptr: points to first free spot in the FIFO
94
95 Input:
96 inp: (value,discriminator) pair
97 enable: Flushes 2 values from FIFO if 2, 1 value from FIFO if 1, no values 
98         from FIFO if 0
99   
100 Output
101 out1: ((value, discriminator),valid) pair of head FIFO
102 out2: ((value, discriminator),valid) pair of second register FIFO
103
104 valid indicates if the output contains a valid discriminator
105 -}
106 inputBuffer ::  InputState -> 
107                 ((DataInt, Discr), RangedWord D2) -> 
108                 (InputState, (ReducerSignal, ReducerSignal))            
109 inputBuffer (State (mem,wrptr)) (inp,enable) = (State (mem',wrptr'),(out1, out2))
110   where
111     out1                  = last mem -- output head of FIFO
112     out2                  = last (init mem) -- output 2nd element
113     -- Update free spot pointer according to value of 'enable' 
114     wrptr'  | enable == 0 = wrptr - 1
115             | enable == 1 = wrptr
116             | otherwise   = wrptr + 1
117     -- Write value to free spot
118     mem''                 = replace mem wrptr (inp,High)
119     -- Flush values at head of fifo according to value of 'enable'
120     mem'    | enable == 0 = mem'' 
121             | enable == 1 = zero +> (init mem'')
122             | otherwise   = zero +> (zero +> (init(init mem'')))
123     zero                  = (((0::DataInt),(0::Discr)),(Low::Bit))
124             
125             
126 {-
127 floating point Adder 
128
129 output discriminator becomes discriminator of the first operant
130
131 State:
132 state: "pipeline" of the fp Adder
133
134 Input:
135 input1: out1 of the FIFO
136 input2: out2 of the FIFO
137 grant: grant signal comming from the controller, determines which value enters 
138        the pipeline
139 mem_out: Value of the output buffer for the read address
140          Read address for the output buffer is the discriminator at the top of 
141         the adder pipeline
142
143 Output:
144 output: ((Value, discriminator),valid) pair at the top of the adder pipeline
145
146 valid indicates if the output contains a valid discriminator
147 -}
148 fpAdder ::  FpAdderState -> 
149             ( ReducerSignal
150             , ReducerSignal
151             , (RangedWord D2, RangedWord D2)
152             , ReducerSignal
153             ) ->        
154             (FpAdderState, ReducerSignal)         
155 fpAdder (State state) (input1, input2, grant, mem_out) = (State state', output)
156   where
157     -- output is head of the pipeline
158     output    = last state
159     -- First value of 'grant' determines operant 1
160     operant1  | (fst grant) == 0  = fst (fst (last state))
161               | (fst grant) == 1  = fst (fst input2)
162               | otherwise         = 0
163     -- Second value of 'grant' determine operant 2
164     operant2  | (snd grant) == 0  = fst (fst input1)
165               | (snd grant) == 1  = fst (fst mem_out)
166               | (otherwise)       = 0
167     -- Determine discriminator for new value
168     discr     | (snd grant) == 0  = snd (fst input1)
169               | (snd grant) == 1  = snd (fst (last state))
170               | otherwise         = 0
171     -- Determine if discriminator should be marked as valid
172     valid     | grant == (2,2)    = Low
173               | otherwise         = High
174     -- Shift addition of the two operants into the pipeline
175     state'    = (((operant1 + operant2),discr),valid) +> (init state)
176
177 {-
178 Output logic - Determines when values are released from blockram to the output
179
180 State:
181 mem: memory belonging to the blockRAM
182 lut: Lookup table that maps discriminators to Index'
183 valid: Lookup table for 'validity' of the content of the blockRAM
184
185 Input:
186 discr: Value of the newest discriminator when it first enters the system. 
187        (-1) otherwise.
188 index: Index belonging to the newest discriminator
189 data_in: value to be written to RAM
190 rdaddr: read address
191 wraddr: write address
192 wrenable: write enabled flag
193
194 Output:
195 data_out: value of RAM at location 'rdaddr'
196 output: Reduced row when ready, (-1) otherwise
197 -}
198 outputter ::  OutputState -> 
199               ( Discr
200               , ArrayIndex
201               , Bit
202               , DataInt
203               , Discr
204               , Discr
205               , Bit
206               ) -> 
207               (OutputState, (ReducerSignal, OutputSignal))                 
208 outputter (State (mem1, mem2, lut, valid))
209   (discr, index, new_discr, data_in, rdaddr, wraddr, wrenable) = 
210   ((State (mem1', mem2', lut', valid')), (data_out, output))
211   where
212     -- Lut is updated when new discriminator/index combination enters system        
213     lut'    | new_discr /= Low  = replace lut discr index
214             | otherwise         = lut
215     -- Location becomes invalid when Reduced row leaves system        
216     valid'' | (new_discr /= Low) && ((valid!discr) /= Low) = 
217                                   replace valid discr Low
218             | otherwise         = valid
219     -- Location becomes invalid when it is fed back into the pipeline
220     valid'  | wrenable == Low   = replace valid'' rdaddr Low
221             | otherwise         = replace valid'' wraddr High
222     (mem1', mem_out1)           = blockRAM mem1 data_in rdaddr wraddr wrenable
223     (mem2', mem_out2)           = blockRAM mem2 data_in discr wraddr wrenable
224     data_out                    = ( ( (mem_out1)
225                                     , rdaddr
226                                     )
227                                   , (valid!rdaddr)
228                                   )
229     -- Reduced row is released when new discriminator enters system
230     -- And the position at the discriminator holds a valid value
231     output                      = ( ( (mem_out2)
232                                     , (lut!discr)
233                                     )
234                                   , (new_discr `hwand` (valid!discr))
235                                   )
236
237 {-
238 Arbiter determines which rules are valid
239
240 Input:
241 fp_out: output of the adder pipeline
242 mem_out: data_out of the output logic
243 inp1: Head of the input FIFO
244 inp2: Second element of input FIFO
245
246 Output:
247 r4 - r0: vector of rules, rule is invalid if it's 0, valid otherwise
248 -}
249 arbiter :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) ->  
250             Vector D5 Bit
251 arbiter (fp_out, mem_out, inp1, inp2) = (r4 +> (r3 +> (r2 +> (r1 +> (singleton r0)))))
252   where -- unpack parameters
253     fp_valid    = snd fp_out
254     next_valid  = snd mem_out
255     inp1_valid  = snd inp1
256     inp2_valid  = snd inp2
257     fp_discr    = snd (fst fp_out)
258     next_discr  = snd (fst mem_out)
259     inp1_discr  = snd (fst inp1)
260     inp2_discr  = snd (fst inp2)
261     -- Apply rules
262     r0  | (fp_valid /= Low) && (next_valid /= Low) && (fp_discr == next_discr)  
263                                       = High
264         | otherwise                   = Low
265     r1  | (fp_valid /= Low) && (inp1_valid /= Low) && (fp_discr == inp1_discr)  
266                                       = High
267         | otherwise                   = Low
268     r2  | (inp1_valid /= Low) && (inp2_valid /= Low) && 
269           (inp1_discr == inp2_discr)  = High
270         | otherwise                   = Low
271     r3  | inp1_valid /= Low           = High
272         | otherwise                   = Low
273     r4                                = High
274
275 {-
276 Controller determines which values are fed into the pipeline
277 and if the write enable flag for the Output RAM should be set
278 to true. Also determines how many values should be flushed from
279 the input FIFO.
280
281 Input:
282 fp_out: output of the adder pipeline
283 mem_out: data_out of the output logic
284 inp1: Head of input FIFO
285 inp2: Second element of input FIFO
286
287 Output:
288 grant: Signal that determines operants for the adder
289 enable: Number of values to be flushed from input buffer
290 wr_enable: Determine if value of the adder should be written to RAM
291 -}
292 controller :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) -> 
293                 ((RangedWord D2, RangedWord D2), RangedWord D2, Bit)
294 controller (fp_out,mem_out,inp1,inp2) = (grant,enable,wr_enable)
295   where
296     -- Arbiter determines which rules are valid
297     valid       = arbiter (fp_out,mem_out,inp1,inp2)
298     -- Determine which values should be fed to the adder
299     grant       = if (valid!(4 :: RangedWord D4) == High) 
300                   then (0,1) 
301                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
302                   then (0,0) 
303                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low])) 
304                   then (1,0) 
305                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
306                   then (2,0) 
307                   else (2,2)
308     -- Determine if some values should be flushed from input FIFO
309     enable      = if (grant == (1,0)) 
310                   then 2 
311                   else if ((grant == (0,0)) || (grant == (2,0))) 
312                   then 1 
313                   else 0
314     -- Determine if the output value of the adder should be written to RAM
315     wr_enable'  = if (valid!(4 :: RangedWord D4) == High) 
316                   then Low 
317                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
318                   then Low 
319                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low]))
320                   then High
321                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
322                   then High 
323                   else High
324     wr_enable   = if ((snd fp_out) /= Low) then wr_enable' else Low
325
326 {-
327 Reducer
328
329 Combines all the earlier defined functions. Uses the second implementation
330 of the input FIFO.
331
332 Parameter: 
333 'n': specifies the max discriminator value.
334
335 State: all the states of the used functions
336
337 Input: (value,index) combination
338
339 Output: reduced row
340 -}
341 {-# ANN reducer TopEntity #-}
342 reducer ::  ReducerState -> 
343             (DataInt, ArrayIndex) -> 
344             (ReducerState, OutputSignal)
345 reducer (State (discrstate,inputstate,fpadderstate,outputstate)) input = 
346   (State (discrstate',inputstate',fpadderstate',outputstate'),output)
347   where
348     (discrstate', discr_out)              = discriminator discrstate input
349     (inputstate',(fifo_out1, fifo_out2))  = inputBuffer inputstate (
350                                             (fst discr_out), enable)
351     (fpadderstate', fp_out)               = fpAdder fpadderstate (fifo_out1, 
352                                                 fifo_out2, grant, mem_out)
353     discr                                 = snd (fst discr_out)
354     new_discr                             = fst (snd discr_out)
355     index                                 = snd (snd discr_out)
356     rdaddr                                = snd (fst fp_out)
357     wraddr                                = rdaddr
358     data_in                               = fst (fst fp_out)
359     (outputstate', (mem_out, output))     = outputter outputstate (discr, 
360                                             index, new_discr, data_in, rdaddr, 
361                                             wraddr, wr_enable)
362     (grant,enable,wr_enable)              = controller (fp_out, mem_out, 
363                                             fifo_out1, fifo_out2)
364
365
366 -- -------------------------------------------------------
367 -- -- Test Functions
368 -- -------------------------------------------------------            
369 --             
370 -- "Default" Run function
371 run func state [] = []
372 run func state (i:input) = o:out
373   where
374     (state', o) = func state i
375     out         = run func state' input
376 -- 
377 -- -- "Special" Run function, also outputs new state      
378 -- run' func state [] = ([],[])   
379 -- run' func state (i:input) = ((o:out), (state':ss))
380 --   where
381 --     (state',o)  = func state i
382 --     (out,ss)         = run' func state' input
383 -- 
384 -- Run reducer
385 runReducer =  ( reduceroutput
386               , validoutput
387               , equal
388               )
389   where
390     input = siminput
391     istate = initstate
392     output = run reducer istate input
393     reduceroutput = P.map fst (filter (\x -> (snd x) /= Low) output)
394     validoutput   = [P.foldl (+) 0 
395                       (P.map (\z -> toInteger (fst z)) 
396                         (filter (\x -> (snd x) == i) input)) | i <- [0..10]]
397     equal = [validoutput!!i == toInteger (fst (reduceroutput!!i)) | 
398               i <- [0..10]]
399 -- 
400 -- -- Generate infinite list of numbers between 1 and 'x'
401 -- randX :: Integer -> [Integer]   
402 -- randX x = randomRs (1,x) (unsafePerformIO newStdGen)
403 -- 
404 -- -- Generate random lists of indexes
405 -- randindex 15 i = randindex 1 i
406 -- randindex m i = (P.take n (repeat i)) P.++ (randindex (m+1) (i+1))
407 --   where
408 --     [n] = P.take 1 rnd
409 --     rnd = randomRs (1,m) (unsafePerformIO newStdGen)
410 -- 
411 -- -- Combine indexes and values to generate random input for the reducer    
412 -- randominput n x = P.zip data_in index_in 
413 --   where
414 --     data_in   = P.map (fromInteger :: Integer -> DataInt) (P.take n (randX x))
415 --     index_in  = P.map (fromInteger :: Integer -> ArrayIndex)
416 --                         (P.take n (randindex 7 0))
417 -- main = 
418 --   do
419 --     putStrLn (show runReducer)
420
421 -- simulate f input s = do
422 --   putStr "Input: "
423 --   putStr $ show input
424 --   putStr "\nInitial State: "
425 --   putStr $ show s
426 --   putStr "\n\n"
427 --   foldl1 (>>) (map (printOutput) output)
428 --   where
429 --     output = run f input s
430
431 initstate :: ReducerState
432 initstate = State
433   ( State ( (255 :: ArrayIndex)
434     , (7 :: SizedWord DiscrSize)
435     )
436   , State ( copy ((0::DataInt,0::Discr),Low)
437     , (2 :: RangedWord AdderDepth)
438     )
439   , State (copy ((0::DataInt,0::Discr),Low))
440   , State ( State (copy (0::DataInt))
441           , State (copy (0::DataInt))
442           , (copy (0::ArrayIndex))
443           , (copy Low)
444           )
445   )
446
447 siminput :: [(DataInt, ArrayIndex)]
448 siminput =  [(13,0),(7,0),(14,0),(14,0),(12,0),(10,0),(19,1),(20,1),(13,1)
449             ,(5,1),(9,1),(16,1),(15,1),(10,2),(13,2),(3,2),(9,2),(19,2),(5,3)
450             ,(5,3),(10,3),(17,3),(14,3),(5,3),(15,3),(11,3),(5,3),(1,3),(8,4)
451             ,(20,4),(8,4),(1,4),(11,4),(10,4),(13,5),(18,5),(5,5),(6,5),(6,5)
452             ,(4,6),(4,6),(11,6),(11,6),(11,6),(1,6),(11,6),(3,6),(12,6),(12,6)
453             ,(2,6),(14,6),(11,7),(13,7),(17,7),(9,7),(19,8),(4,9),(18,10)
454             ,(6,10),(18,11),(1,12),(3,12),(14,12),(18,12),(14,12),(6,13)
455             ,(9,13),(11,14),(4,14),(1,14),(14,14),(14,14),(6,14),(11,15)
456             ,(13,15),(7,15),(2,16),(16,16),(17,16),(5,16),(20,16),(17,16)
457             ,(14,16),(18,17),(13,17),(1,17),(19,18),(1,18),(20,18),(4,18)
458             ,(5,19),(4,19),(6,19),(19,19),(4,19),(3,19),(7,19),(13,19),(19,19)
459             ,(8,19)
460             ]