Add boolean or and and, tuple fst and snd function.
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell #-}
2 module Reducer where
3
4 import System.Random
5 import System.IO.Unsafe (unsafePerformIO,unsafeInterleaveIO)
6
7 import qualified Prelude as P
8 import CLasH.HardwareTypes
9 import CLasH.Translator.Annotations
10
11 type DataSize       = D8
12 type IndexSize      = D8
13 type DiscrSize      = D3
14 type DiscrRange     = D7
15 type AdderDepth     = D2
16
17 type DataInt        = SizedWord DataSize
18 type ArrayIndex     = SizedWord IndexSize
19 type Discr          = RangedWord DiscrRange
20
21 type RAM a          = Vector (DiscrRange :+: D1) a
22
23 type ReducerState   = State ( DiscrState
24                       , InputState
25                       , FpAdderState
26                       , OutputState
27                       )
28 type ReducerSignal  = ( ( DataInt
29                         , Discr
30                         )
31                       , Bit
32                       )
33
34 type MemState a      = State (RAM a)
35                       
36 type OutputSignal   = ( (DataInt
37                         , ArrayIndex
38                         )
39                       , Bit
40                       )
41
42 type DiscrState     = State ( ArrayIndex
43                       , SizedWord DiscrSize
44                       )
45                      
46 type InputState     = State ( Vector (AdderDepth :+: D1) ReducerSignal
47                       , RangedWord AdderDepth
48                       )
49
50 type FpAdderState   = State (Vector AdderDepth ReducerSignal)
51
52 type OutputState    = State ( MemState DataInt
53                             , RAM ArrayIndex
54                             , RAM Bit
55                       )
56 {-
57 Discriminator adds a discriminator to each input value
58
59 State:
60 prev_index: previous index
61 cur_discr: current discriminator
62
63 Input:
64 data_in: input value
65 index: row index
66
67 Output:
68 data_in: output value
69 discr: discriminator belonging to output value
70 new_discr: value of new discriminator, is -1 if cur_discr hasn't changed
71 index: Index belonging to the new discriminator 
72 -}
73 discriminator ::  DiscrState -> (DataInt, ArrayIndex) -> 
74                   ( DiscrState
75                   , ((DataInt, Discr), (Bit, ArrayIndex))
76                   )
77 discriminator (State (prev_index,cur_discr)) (data_in, index) = 
78   (State (prev_index', cur_discr'), ((data_in, discr),(new_discr, index)))
79   where
80     -- Update discriminator if index changes
81     cur_discr'  | prev_index == index = cur_discr
82                 | otherwise           = cur_discr + 1
83     -- Notify OutputBuffer if a new discriminator becomes in use
84     new_discr   | prev_index == index = Low
85                 | otherwise           = High
86     prev_index'                       = index
87     discr                             = fromSizedWord cur_discr'
88
89 {-
90 Second attempt at Fifo
91 Uses "write pointer"... ugly...
92 Can potentially be mapped to hardware
93
94 State:
95 mem: content of the FIFO
96 wrptr: points to first free spot in the FIFO
97
98 Input:
99 inp: (value,discriminator) pair
100 enable: Flushes 2 values from FIFO if 2, 1 value from FIFO if 1, no values 
101         from FIFO if 0
102   
103 Output
104 out1: ((value, discriminator),valid) pair of head FIFO
105 out2: ((value, discriminator),valid) pair of second register FIFO
106
107 valid indicates if the output contains a valid discriminator
108 -}
109 inputBuffer ::  InputState -> 
110                 ((DataInt, Discr), RangedWord D2) -> 
111                 (InputState, (ReducerSignal, ReducerSignal))            
112 inputBuffer (State (mem,wrptr)) (inp,enable) = (State (mem',wrptr'),(out1, out2))
113   where
114     out1                  = last mem -- output head of FIFO
115     out2                  = last (init mem) -- output 2nd element
116     -- Update free spot pointer according to value of 'enable' 
117     wrptr'  | enable == 0 = wrptr - 1
118             | enable == 1 = wrptr
119             | otherwise   = wrptr + 1
120     -- Write value to free spot
121     mem''                 = replace mem wrptr (inp,High)
122     -- Flush values at head of fifo according to value of 'enable'
123     mem'    | enable == 0 = mem'' 
124             | enable == 1 = zero +> (init mem'')
125             | otherwise   = zero +> (zero +> (init(init mem'')))
126     zero                  = (((0::DataInt),(0::Discr)),(Low::Bit))
127             
128             
129 {-
130 floating point Adder 
131
132 output discriminator becomes discriminator of the first operant
133
134 State:
135 state: "pipeline" of the fp Adder
136
137 Input:
138 input1: out1 of the FIFO
139 input2: out2 of the FIFO
140 grant: grant signal comming from the controller, determines which value enters 
141        the pipeline
142 mem_out: Value of the output buffer for the read address
143          Read address for the output buffer is the discriminator at the top of 
144         the adder pipeline
145
146 Output:
147 output: ((Value, discriminator),valid) pair at the top of the adder pipeline
148
149 valid indicates if the output contains a valid discriminator
150 -}
151 fpAdder ::  FpAdderState -> 
152             ( ReducerSignal
153             , ReducerSignal
154             , (RangedWord D2, RangedWord D2)
155             , ReducerSignal
156             ) ->        
157             (FpAdderState, ReducerSignal)         
158 fpAdder (State state) (input1, input2, grant, mem_out) = (State state', output)
159   where
160     -- output is head of the pipeline
161     output    = last state
162     -- First value of 'grant' determines operant 1
163     operant1  | (fst grant) == 0  = fst (fst (last state))
164               | (fst grant) == 1  = fst (fst input2)
165               | otherwise         = 0
166     -- Second value of 'grant' determine operant 2
167     operant2  | (snd grant) == 0  = fst (fst input1)
168               | (snd grant) == 1  = fst (fst mem_out)
169               | (otherwise)       = 0
170     -- Determine discriminator for new value
171     discr     | (snd grant) == 0  = snd (fst input1)
172               | (snd grant) == 1  = snd (fst (last state))
173               | otherwise         = 0
174     -- Determine if discriminator should be marked as valid
175     valid     | grant == (2,2)    = Low
176               | otherwise         = High
177     -- Shift addition of the two operants into the pipeline
178     state'    = (((operant1 + operant2),discr),valid) +> (init state)
179     
180
181 {- 
182 first attempt at BlockRAM
183
184 State:
185 mem: content of the RAM
186
187 Input:
188 data_in: input value to be written to 'mem' at location 'wraddr'
189 rdaddr: read address
190 wraddr: write address
191 wrenable: write enable flag
192
193 Output:
194 data_out: value of 'mem' at location 'rdaddr'
195 -}
196 blockRAM :: (MemState a) -> 
197             ( a
198             , Discr
199             , Discr
200             , Discr
201             , Bit
202             ) -> 
203             (MemState a, (a, a) )
204 blockRAM (State mem) (data_in, rdaddr1, rdaddr2, wraddr, wrenable) = 
205   ((State mem'), (data_out1,data_out2))
206   where
207     data_out1               = mem!rdaddr1 
208     data_out2               = mem!rdaddr2
209     -- Only write data_in to memory if write is enabled
210     mem'  | wrenable == Low = mem
211           | otherwise       = replace mem wraddr data_in
212
213 {-
214 Output logic - Determines when values are released from blockram to the output
215
216 State:
217 mem: memory belonging to the blockRAM
218 lut: Lookup table that maps discriminators to Index'
219 valid: Lookup table for 'validity' of the content of the blockRAM
220
221 Input:
222 discr: Value of the newest discriminator when it first enters the system. 
223        (-1) otherwise.
224 index: Index belonging to the newest discriminator
225 data_in: value to be written to RAM
226 rdaddr: read address
227 wraddr: write address
228 wrenable: write enabled flag
229
230 Output:
231 data_out: value of RAM at location 'rdaddr'
232 output: Reduced row when ready, (-1) otherwise
233 -}
234 outputter ::  OutputState -> 
235               ( Discr
236               , ArrayIndex
237               , Bit
238               , DataInt
239               , Discr
240               , Discr
241               , Bit
242               ) -> 
243               (OutputState, (ReducerSignal, OutputSignal))                 
244 outputter (State (mem,  lut, valid))
245   (discr, index, new_discr, data_in, rdaddr, wraddr, wrenable) = 
246   ((State (mem', lut', valid')), (data_out, output))
247   where
248     -- Lut is updated when new discriminator/index combination enters system        
249     lut'    | new_discr /= Low  = replace lut discr index
250             | otherwise         = lut
251     -- Location becomes invalid when Reduced row leaves system        
252     valid'' | (new_discr /= Low) && ((valid!discr) /= Low) = 
253                                   replace valid discr Low
254             | otherwise         = valid
255     -- Location becomes invalid when it is fed back into the pipeline
256     valid'  | wrenable == Low   = replace valid'' rdaddr Low
257             | otherwise         = replace valid'' wraddr High
258     (mem', mem_out)             = blockRAM mem ( data_in
259                                                 , rdaddr
260                                                 , discr
261                                                 , wraddr
262                                                 , wrenable
263                                                 )
264     data_out                    = ( ( (fst mem_out)
265                                     , rdaddr
266                                     )
267                                   , (valid!rdaddr)
268                                   )
269     -- Reduced row is released when new discriminator enters system
270     -- And the position at the discriminator holds a valid value
271     output                      = ( ( (snd mem_out)
272                                     , (lut!discr)
273                                     )
274                                   , (new_discr `hwand` (valid!discr))
275                                   )
276
277 {-
278 Arbiter determines which rules are valid
279
280 Input:
281 fp_out: output of the adder pipeline
282 mem_out: data_out of the output logic
283 inp1: Head of the input FIFO
284 inp2: Second element of input FIFO
285
286 Output:
287 r4 - r0: vector of rules, rule is invalid if it's 0, valid otherwise
288 -}
289 arbiter :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) ->  
290             Vector D5 Bit
291 arbiter (fp_out, mem_out, inp1, inp2) = (r4 +> (r3 +> (r2 +> (r1 +> (singleton r0)))))
292   where -- unpack parameters
293     fp_valid    = snd fp_out
294     next_valid  = snd mem_out
295     inp1_valid  = snd inp1
296     inp2_valid  = snd inp2
297     fp_discr    = snd (fst fp_out)
298     next_discr  = snd (fst mem_out)
299     inp1_discr  = snd (fst inp1)
300     inp2_discr  = snd (fst inp2)
301     -- Apply rules
302     r0  | (fp_valid /= Low) && (next_valid /= Low) && (fp_discr == next_discr)  
303                                       = High
304         | otherwise                   = Low
305     r1  | (fp_valid /= Low) && (inp1_valid /= Low) && (fp_discr == inp1_discr)  
306                                       = High
307         | otherwise                   = Low
308     r2  | (inp1_valid /= Low) && (inp2_valid /= Low) && 
309           (inp1_discr == inp2_discr)  = High
310         | otherwise                   = Low
311     r3  | inp1_valid /= Low           = High
312         | otherwise                   = Low
313     r4                                = High
314
315 {-
316 Controller determines which values are fed into the pipeline
317 and if the write enable flag for the Output RAM should be set
318 to true. Also determines how many values should be flushed from
319 the input FIFO.
320
321 Input:
322 fp_out: output of the adder pipeline
323 mem_out: data_out of the output logic
324 inp1: Head of input FIFO
325 inp2: Second element of input FIFO
326
327 Output:
328 grant: Signal that determines operants for the adder
329 enable: Number of values to be flushed from input buffer
330 wr_enable: Determine if value of the adder should be written to RAM
331 -}
332 controller :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) -> 
333                 ((RangedWord D2, RangedWord D2), RangedWord D2, Bit)
334 controller (fp_out,mem_out,inp1,inp2) = (grant,enable,wr_enable)
335   where
336     -- Arbiter determines which rules are valid
337     valid       = arbiter (fp_out,mem_out,inp1,inp2)
338     -- Determine which values should be fed to the adder
339     grant       = if (valid!(4 :: RangedWord D4) == High) 
340                   then (0,1) 
341                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
342                   then (0,0) 
343                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low])) 
344                   then (1,0) 
345                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
346                   then (2,0) 
347                   else (2,2)
348     -- Determine if some values should be flushed from input FIFO
349     enable      = if (grant == (1,0)) 
350                   then 2 
351                   else if ((grant == (0,0)) || (grant == (2,0))) 
352                   then 1 
353                   else 0
354     -- Determine if the output value of the adder should be written to RAM
355     wr_enable'  = if (valid!(4 :: RangedWord D4) == High) 
356                   then Low 
357                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
358                   then Low 
359                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low]))
360                   then High
361                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
362                   then High 
363                   else High
364     wr_enable   = if ((snd fp_out) /= Low) then wr_enable' else Low
365
366 {-
367 Reducer
368
369 Combines all the earlier defined functions. Uses the second implementation
370 of the input FIFO.
371
372 Parameter: 
373 'n': specifies the max discriminator value.
374
375 State: all the states of the used functions
376
377 Input: (value,index) combination
378
379 Output: reduced row
380 -}
381 {-# ANN reducer TopEntity #-}
382 reducer ::  ReducerState -> 
383             (DataInt, ArrayIndex) -> 
384             (ReducerState, OutputSignal)
385 reducer (State (discrstate,inputstate,fpadderstate,outputstate)) input = 
386   (State (discrstate',inputstate',fpadderstate',outputstate'),output)
387   where
388     (discrstate', discr_out)              = discriminator discrstate input
389     (inputstate',(fifo_out1, fifo_out2))  = inputBuffer inputstate (
390                                             (fst discr_out), enable)
391     (fpadderstate', fp_out)               = fpAdder fpadderstate (fifo_out1, 
392                                                 fifo_out2, grant, mem_out)
393     discr                                 = snd (fst discr_out)
394     new_discr                             = fst (snd discr_out)
395     index                                 = snd (snd discr_out)
396     rdaddr                                = snd (fst fp_out)
397     wraddr                                = rdaddr
398     data_in                               = fst (fst fp_out)
399     (outputstate', (mem_out, output))     = outputter outputstate (discr, 
400                                             index, new_discr, data_in, rdaddr, 
401                                             wraddr, wr_enable)
402     (grant,enable,wr_enable)              = controller (fp_out, mem_out, 
403                                             fifo_out1, fifo_out2)
404
405
406 -- -------------------------------------------------------
407 -- -- Test Functions
408 -- -------------------------------------------------------            
409 --             
410 -- "Default" Run function
411 run func state [] = []
412 run func state (i:input) = o:out
413   where
414     (state', o) = func state i
415     out         = run func state' input
416 -- 
417 -- -- "Special" Run function, also outputs new state      
418 -- run' func state [] = ([],[])   
419 -- run' func state (i:input) = ((o:out), (state':ss))
420 --   where
421 --     (state',o)  = func state i
422 --     (out,ss)         = run' func state' input
423 -- 
424 -- Run reducer
425 runReducer =  ( reduceroutput
426               , validoutput
427               , equal
428               )
429   where
430     input = siminput
431     istate = initstate
432     output = run reducer istate input
433     reduceroutput = P.map fst (filter (\x -> (snd x) /= Low) output)
434     validoutput   = [P.foldl (+) 0 
435                       (P.map (\z -> toInteger (fst z)) 
436                         (filter (\x -> (snd x) == i) input)) | i <- [0..10]]
437     equal = [validoutput!!i == toInteger (fst (reduceroutput!!i)) | 
438               i <- [0..10]]
439 -- 
440 -- -- Generate infinite list of numbers between 1 and 'x'
441 -- randX :: Integer -> [Integer]   
442 -- randX x = randomRs (1,x) (unsafePerformIO newStdGen)
443 -- 
444 -- -- Generate random lists of indexes
445 -- randindex 15 i = randindex 1 i
446 -- randindex m i = (P.take n (repeat i)) P.++ (randindex (m+1) (i+1))
447 --   where
448 --     [n] = P.take 1 rnd
449 --     rnd = randomRs (1,m) (unsafePerformIO newStdGen)
450 -- 
451 -- -- Combine indexes and values to generate random input for the reducer    
452 -- randominput n x = P.zip data_in index_in 
453 --   where
454 --     data_in   = P.map (fromInteger :: Integer -> DataInt) (P.take n (randX x))
455 --     index_in  = P.map (fromInteger :: Integer -> ArrayIndex)
456 --                         (P.take n (randindex 7 0))
457 -- main = 
458 --   do
459 --     putStrLn (show runReducer)
460
461 -- simulate f input s = do
462 --   putStr "Input: "
463 --   putStr $ show input
464 --   putStr "\nInitial State: "
465 --   putStr $ show s
466 --   putStr "\n\n"
467 --   foldl1 (>>) (map (printOutput) output)
468 --   where
469 --     output = run f input s
470
471 initstate :: ReducerState
472 initstate = State
473   ( State ( (255 :: ArrayIndex)
474     , (7 :: SizedWord DiscrSize)
475     )
476   , State ( copy ((0::DataInt,0::Discr),Low)
477     , (2 :: RangedWord AdderDepth)
478     )
479   , State (copy ((0::DataInt,0::Discr),Low))
480   , State ( State (copy (0::DataInt))
481           , (copy (0::ArrayIndex))
482           , (copy Low)
483           )
484   )
485
486 siminput :: [(DataInt, ArrayIndex)]
487 siminput =  [(13,0),(7,0),(14,0),(14,0),(12,0),(10,0),(19,1),(20,1),(13,1)
488             ,(5,1),(9,1),(16,1),(15,1),(10,2),(13,2),(3,2),(9,2),(19,2),(5,3)
489             ,(5,3),(10,3),(17,3),(14,3),(5,3),(15,3),(11,3),(5,3),(1,3),(8,4)
490             ,(20,4),(8,4),(1,4),(11,4),(10,4),(13,5),(18,5),(5,5),(6,5),(6,5)
491             ,(4,6),(4,6),(11,6),(11,6),(11,6),(1,6),(11,6),(3,6),(12,6),(12,6)
492             ,(2,6),(14,6),(11,7),(13,7),(17,7),(9,7),(19,8),(4,9),(18,10)
493             ,(6,10),(18,11),(1,12),(3,12),(14,12),(18,12),(14,12),(6,13)
494             ,(9,13),(11,14),(4,14),(1,14),(14,14),(14,14),(6,14),(11,15)
495             ,(13,15),(7,15),(2,16),(16,16),(17,16),(5,16),(20,16),(17,16)
496             ,(14,16),(18,17),(13,17),(1,17),(19,18),(1,18),(20,18),(4,18)
497             ,(5,19),(4,19),(6,19),(19,19),(4,19),(3,19),(7,19),(13,19),(19,19)
498             ,(8,19)
499             ]