Merge branch 'cλash' of http://git.stderr.nl/matthijs/projects/master-project
[matthijs/master-project/cλash.git] / reducer.hs
1 {-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies #-}
2 module Reducer where
3
4 import System.Random
5 import System.IO.Unsafe (unsafePerformIO,unsafeInterleaveIO)
6
7 import qualified Prelude as P
8 import CLasH.HardwareTypes
9 import CLasH.Translator.Annotations
10
11 type DataSize       = D8
12 type IndexSize      = D8
13 type DiscrSize      = D3
14 type DiscrRange     = D7
15 type AdderDepth     = D2
16
17 type DataInt        = SizedWord DataSize
18 type ArrayIndex     = SizedWord IndexSize
19 type Discr          = RangedWord DiscrRange
20
21 type ReducerState   = State ( DiscrState
22                       , InputState
23                       , FpAdderState
24                       , OutputState
25                       )
26 type ReducerSignal  = ( ( DataInt
27                         , Discr
28                         )
29                       , Bit
30                       )
31                       
32 type OutputSignal   = ( (DataInt
33                         , ArrayIndex
34                         )
35                       , Bit
36                       )
37
38 type DiscrState     = State ( ArrayIndex
39                       , SizedWord DiscrSize
40                       )
41                      
42 type InputState     = State ( Vector (AdderDepth :+: D1) ReducerSignal
43                       , RangedWord AdderDepth
44                       )
45
46 type FpAdderState   = State (Vector AdderDepth ReducerSignal)
47
48 type OutputState    = State ( MemState DiscrRange DataInt
49                             , MemState DiscrRange DataInt
50                             , MemState DiscrRange ArrayIndex
51                             , RAM DiscrRange Bit
52                       )
53 {-
54 Discriminator adds a discriminator to each input value
55
56 State:
57 prev_index: previous index
58 cur_discr: current discriminator
59
60 Input:
61 data_in: input value
62 index: row index
63
64 Output:
65 data_in: output value
66 discr: discriminator belonging to output value
67 new_discr: value of new discriminator, is -1 if cur_discr hasn't changed
68 index: Index belonging to the new discriminator 
69 -}
70 discriminator ::  DiscrState -> (DataInt, ArrayIndex) -> 
71                   ( DiscrState
72                   , ((DataInt, Discr), (Bit, ArrayIndex))
73                   )
74 discriminator (State (prev_index,cur_discr)) (data_in, index) = 
75   (State (prev_index', cur_discr'), ((data_in, discr),(new_discr, index)))
76   where
77     -- Update discriminator if index changes
78     cur_discr'  | prev_index == index = cur_discr
79                 | otherwise           = cur_discr + 1
80     -- Notify OutputBuffer if a new discriminator becomes in use
81     new_discr   | prev_index == index = Low
82                 | otherwise           = High
83     prev_index'                       = index
84     discr                             = fromSizedWord cur_discr'
85
86 {-
87 Second attempt at Fifo
88 Uses "write pointer"... ugly...
89 Can potentially be mapped to hardware
90
91 State:
92 mem: content of the FIFO
93 wrptr: points to first free spot in the FIFO
94
95 Input:
96 inp: (value,discriminator) pair
97 enable: Flushes 2 values from FIFO if 2, 1 value from FIFO if 1, no values 
98         from FIFO if 0
99   
100 Output
101 out1: ((value, discriminator),valid) pair of head FIFO
102 out2: ((value, discriminator),valid) pair of second register FIFO
103
104 valid indicates if the output contains a valid discriminator
105 -}
106 inputBuffer ::  InputState -> 
107                 ((DataInt, Discr), RangedWord D2) -> 
108                 (InputState, ReducerSignal, ReducerSignal)
109 inputBuffer (State (mem,wrptr)) (inp,enable) = (State (mem',wrptr'),out1, out2)
110   where
111     out1                  = last mem -- output head of FIFO
112     out2                  = last (init mem) -- output 2nd element
113     -- Update free spot pointer according to value of 'enable' 
114     wrptr'  | enable == 0 = wrptr - 1
115             | enable == 1 = wrptr
116             | otherwise   = wrptr + 1
117     -- Write value to free spot
118     mem''                 = replace mem wrptr (inp,High)
119     -- Flush values at head of fifo according to value of 'enable'
120     mem'    | enable == 0 = mem'' 
121             | enable == 1 = zero +> (init mem'')
122             | otherwise   = zero +> (zero +> (init(init mem'')))
123     zero                  = (((0::DataInt),(0::Discr)),(Low::Bit))
124             
125             
126 {-
127 floating point Adder 
128
129 output discriminator becomes discriminator of the first operant
130
131 State:
132 state: "pipeline" of the fp Adder
133
134 Input:
135 input1: out1 of the FIFO
136 input2: out2 of the FIFO
137 grant: grant signal comming from the controller, determines which value enters 
138        the pipeline
139 mem_out: Value of the output buffer for the read address
140          Read address for the output buffer is the discriminator at the top of 
141         the adder pipeline
142
143 Output:
144 output: ((Value, discriminator),valid) pair at the top of the adder pipeline
145
146 valid indicates if the output contains a valid discriminator
147 -}
148 fpAdder ::  FpAdderState -> 
149             ( ReducerSignal
150             , ReducerSignal
151             , (RangedWord D2, RangedWord D2)
152             , ReducerSignal
153             ) ->        
154             (FpAdderState, ReducerSignal)         
155 fpAdder (State state) (input1, input2, grant, mem_out) = (State state', output)
156   where
157     -- output is head of the pipeline
158     output    = last state
159     -- First value of 'grant' determines operant 1
160     operant1  | (fst grant) == 0  = fst (fst (last state))
161               | (fst grant) == 1  = fst (fst input2)
162               | otherwise         = 0
163     -- Second value of 'grant' determine operant 2
164     operant2  | (snd grant) == 0  = fst (fst input1)
165               | (snd grant) == 1  = fst (fst mem_out)
166               | (otherwise)       = 0
167     -- Determine discriminator for new value
168     discr     | (snd grant) == 0  = snd (fst input1)
169               | (snd grant) == 1  = snd (fst (last state))
170               | otherwise         = 0
171     -- Determine if discriminator should be marked as valid
172     valid     | grant == (2,2)    = Low
173               | otherwise         = High
174     -- Shift addition of the two operants into the pipeline
175     state'    = (((operant1 + operant2),discr),valid) +> (init state)
176
177 {-
178 Output logic - Determines when values are released from blockram to the output
179
180 State:
181 mem: memory belonging to the blockRAM
182 lut: Lookup table that maps discriminators to Index'
183 valid: Lookup table for 'validity' of the content of the blockRAM
184
185 Input:
186 discr: Value of the newest discriminator when it first enters the system. 
187        (-1) otherwise.
188 index: Index belonging to the newest discriminator
189 data_in: value to be written to RAM
190 rdaddr: read address
191 wraddr: write address
192 wrenable: write enabled flag
193
194 Output:
195 data_out: value of RAM at location 'rdaddr'
196 output: Reduced row when ready, (-1) otherwise
197 -}
198 outputter ::  OutputState -> 
199               ( Discr
200               , ArrayIndex
201               , Bit
202               , DataInt
203               , Discr
204               , Discr
205               , Bit
206               ) -> 
207               (OutputState, ReducerSignal, OutputSignal)
208 outputter (State (mem1, mem2, lut, valid))
209   (discr, index, new_discr, data_in, rdaddr, wraddr, wrenable) = 
210   ((State (mem1', mem2', lut', valid')), data_out, output)
211   where
212     -- Lut is updated when new discriminator/index combination enters system        
213     (lut', lut_out)             = blockRAM lut index discr discr new_discr
214     -- Location becomes invalid when Reduced row leaves system        
215     valid'' | (new_discr /= Low) && ((valid!discr) /= Low) = 
216                                   replace valid discr Low
217             | otherwise         = valid
218     -- Location becomes invalid when it is fed back into the pipeline
219     valid'  | wrenable == Low   = replace valid'' rdaddr Low
220             | otherwise         = replace valid'' wraddr High
221     (mem1', mem_out1)           = blockRAM mem1 data_in rdaddr wraddr wrenable
222     (mem2', mem_out2)           = blockRAM mem2 data_in discr wraddr wrenable
223     data_out                    = ( ( (mem_out1)
224                                     , rdaddr
225                                     )
226                                   , (valid!rdaddr)
227                                   )
228     -- Reduced row is released when new discriminator enters system
229     -- And the position at the discriminator holds a valid value
230     output                      = ( ( (mem_out2)
231                                     , (lut_out)
232                                     )
233                                   , (new_discr `hwand` (valid!discr))
234                                   )
235
236 {-
237 Arbiter determines which rules are valid
238
239 Input:
240 fp_out: output of the adder pipeline
241 mem_out: data_out of the output logic
242 inp1: Head of the input FIFO
243 inp2: Second element of input FIFO
244
245 Output:
246 r4 - r0: vector of rules, rule is invalid if it's 0, valid otherwise
247 -}
248 arbiter :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) ->  
249             Vector D5 Bit
250 arbiter (fp_out, mem_out, inp1, inp2) = (r4 +> (r3 +> (r2 +> (r1 +> (singleton r0)))))
251   where -- unpack parameters
252     fp_valid    = snd fp_out
253     next_valid  = snd mem_out
254     inp1_valid  = snd inp1
255     inp2_valid  = snd inp2
256     fp_discr    = snd (fst fp_out)
257     next_discr  = snd (fst mem_out)
258     inp1_discr  = snd (fst inp1)
259     inp2_discr  = snd (fst inp2)
260     -- Apply rules
261     r0  | (fp_valid /= Low) && (next_valid /= Low) && (fp_discr == next_discr)  
262                                       = High
263         | otherwise                   = Low
264     r1  | (fp_valid /= Low) && (inp1_valid /= Low) && (fp_discr == inp1_discr)  
265                                       = High
266         | otherwise                   = Low
267     r2  | (inp1_valid /= Low) && (inp2_valid /= Low) && 
268           (inp1_discr == inp2_discr)  = High
269         | otherwise                   = Low
270     r3  | inp1_valid /= Low           = High
271         | otherwise                   = Low
272     r4                                = High
273
274 {-
275 Controller determines which values are fed into the pipeline
276 and if the write enable flag for the Output RAM should be set
277 to true. Also determines how many values should be flushed from
278 the input FIFO.
279
280 Input:
281 fp_out: output of the adder pipeline
282 mem_out: data_out of the output logic
283 inp1: Head of input FIFO
284 inp2: Second element of input FIFO
285
286 Output:
287 grant: Signal that determines operants for the adder
288 enable: Number of values to be flushed from input buffer
289 wr_enable: Determine if value of the adder should be written to RAM
290 -}
291 controller :: (ReducerSignal, ReducerSignal, ReducerSignal, ReducerSignal) -> 
292                 ((RangedWord D2, RangedWord D2), RangedWord D2, Bit)
293 controller (fp_out,mem_out,inp1,inp2) = (grant,enable,wr_enable)
294   where
295     -- Arbiter determines which rules are valid
296     valid       = arbiter (fp_out,mem_out,inp1,inp2)
297     -- Determine which values should be fed to the adder
298     grant       = if (valid!(4 :: RangedWord D4) == High) 
299                   then (0,1) 
300                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
301                   then (0,0) 
302                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low])) 
303                   then (1,0) 
304                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
305                   then (2,0) 
306                   else (2,2)
307     -- Determine if some values should be flushed from input FIFO
308     enable      = if (grant == (1,0)) 
309                   then 2 
310                   else if ((grant == (0,0)) || (grant == (2,0))) 
311                   then 1 
312                   else 0
313     -- Determine if the output value of the adder should be written to RAM
314     wr_enable'  = if (valid!(4 :: RangedWord D4) == High) 
315                   then Low 
316                   else if ((drop d3 valid) == $(vectorTH [High,Low])) 
317                   then Low 
318                   else if ((drop d2 valid) == $(vectorTH [High,Low,Low]))
319                   then High
320                   else if ((drop d1 valid) == $(vectorTH [High,Low,Low,Low])) 
321                   then High 
322                   else High
323     wr_enable   = if ((snd fp_out) /= Low) then wr_enable' else Low
324
325 {-
326 Reducer
327
328 Combines all the earlier defined functions. Uses the second implementation
329 of the input FIFO.
330
331 Parameter: 
332 'n': specifies the max discriminator value.
333
334 State: all the states of the used functions
335
336 Input: (value,index) combination
337
338 Output: reduced row
339 -}
340 {-# ANN reducer TopEntity #-}
341 reducer ::  ReducerState -> 
342             (DataInt, ArrayIndex) -> 
343             (ReducerState, OutputSignal)
344 reducer (State (discrstate,inputstate,fpadderstate,outputstate)) input = 
345   (State (discrstate',inputstate',fpadderstate',outputstate'),output)
346   where
347     (discrstate', discr_out)              = discriminator discrstate input
348     (inputstate',fifo_out1, fifo_out2)    = inputBuffer inputstate (
349                                             (fst discr_out), enable)
350     (fpadderstate', fp_out)               = fpAdder fpadderstate (fifo_out1, 
351                                                 fifo_out2, grant, mem_out)
352     discr                                 = snd (fst discr_out)
353     new_discr                             = fst (snd discr_out)
354     index                                 = snd (snd discr_out)
355     rdaddr                                = snd (fst fp_out)
356     wraddr                                = rdaddr
357     data_in                               = fst (fst fp_out)
358     (outputstate', mem_out, output)       = outputter outputstate (discr,
359                                             index, new_discr, data_in, rdaddr, 
360                                             wraddr, wr_enable)
361     (grant,enable,wr_enable)              = controller (fp_out, mem_out, 
362                                             fifo_out1, fifo_out2)
363
364
365 -- -------------------------------------------------------
366 -- -- Test Functions
367 -- -------------------------------------------------------            
368 --             
369 -- "Default" Run function
370 run func state [] = []
371 run func state (i:input) = o:out
372   where
373     (state', o) = func state i
374     out         = run func state' input
375 -- 
376 -- -- "Special" Run function, also outputs new state      
377 -- run' func state [] = ([],[])   
378 -- run' func state (i:input) = ((o:out), (state':ss))
379 --   where
380 --     (state',o)  = func state i
381 --     (out,ss)         = run' func state' input
382 -- 
383 -- Run reducer
384 runReducer =  ( reduceroutput
385               , validoutput
386               , equal
387               )
388   where
389     input = siminput
390     istate = initstate
391     output = run reducer istate input
392     reduceroutput = P.map fst (filter (\x -> (snd x) /= Low) output)
393     validoutput   = [P.foldl (+) 0 
394                       (P.map (\z -> toInteger (fst z)) 
395                         (filter (\x -> (snd x) == i) input)) | i <- [0..10]]
396     equal = [validoutput!!i == toInteger (fst (reduceroutput!!i)) | 
397               i <- [0..10]]
398 -- 
399 -- -- Generate infinite list of numbers between 1 and 'x'
400 -- randX :: Integer -> [Integer]   
401 -- randX x = randomRs (1,x) (unsafePerformIO newStdGen)
402 -- 
403 -- -- Generate random lists of indexes
404 -- randindex 15 i = randindex 1 i
405 -- randindex m i = (P.take n (repeat i)) P.++ (randindex (m+1) (i+1))
406 --   where
407 --     [n] = P.take 1 rnd
408 --     rnd = randomRs (1,m) (unsafePerformIO newStdGen)
409 -- 
410 -- -- Combine indexes and values to generate random input for the reducer    
411 -- randominput n x = P.zip data_in index_in 
412 --   where
413 --     data_in   = P.map (fromInteger :: Integer -> DataInt) (P.take n (randX x))
414 --     index_in  = P.map (fromInteger :: Integer -> ArrayIndex)
415 --                         (P.take n (randindex 7 0))
416 -- main = 
417 --   do
418 --     putStrLn (show runReducer)
419
420 -- simulate f input s = do
421 --   putStr "Input: "
422 --   putStr $ show input
423 --   putStr "\nInitial State: "
424 --   putStr $ show s
425 --   putStr "\n\n"
426 --   foldl1 (>>) (map (printOutput) output)
427 --   where
428 --     output = run f input s
429
430 initstate :: ReducerState
431 initstate = State
432   ( State ( (255 :: ArrayIndex)
433     , (7 :: SizedWord DiscrSize)
434     )
435   , State ( copy ((0::DataInt,0::Discr),Low)
436     , (2 :: RangedWord AdderDepth)
437     )
438   , State (copy ((0::DataInt,0::Discr),Low))
439   , State ( State (copy (0::DataInt))
440           , State (copy (0::DataInt))
441           , State (copy (0::ArrayIndex))
442           , (copy Low)
443           )
444   )
445
446 {-# ANN siminput TestInput #-}
447 siminput :: [(DataInt, ArrayIndex)]
448 siminput =  [(13,0)::(DataInt, ArrayIndex),(7,0)::(DataInt, ArrayIndex),(14,0)::(DataInt, ArrayIndex),(14,0)::(DataInt, ArrayIndex),(12,0)::(DataInt, ArrayIndex),(10,0)::(DataInt, ArrayIndex),(19,1)::(DataInt, ArrayIndex),(20,1)::(DataInt, ArrayIndex),(13,1)::(DataInt, ArrayIndex)
449             ,(5,1)::(DataInt, ArrayIndex),(9,1)::(DataInt, ArrayIndex),(16,1)::(DataInt, ArrayIndex),(15,1)::(DataInt, ArrayIndex),(10,2)::(DataInt, ArrayIndex),(13,2)::(DataInt, ArrayIndex),(3,2)::(DataInt, ArrayIndex),(9,2)::(DataInt, ArrayIndex),(19,2)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex)
450             ,(5,3)::(DataInt, ArrayIndex),(10,3)::(DataInt, ArrayIndex),(17,3)::(DataInt, ArrayIndex),(14,3)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex),(15,3)::(DataInt, ArrayIndex),(11,3)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex),(1,3)::(DataInt, ArrayIndex),(8,4)::(DataInt, ArrayIndex)
451             ,(20,4)::(DataInt, ArrayIndex),(8,4)::(DataInt, ArrayIndex),(1,4)::(DataInt, ArrayIndex),(11,4)::(DataInt, ArrayIndex),(10,4)::(DataInt, ArrayIndex),(13,5)::(DataInt, ArrayIndex),(18,5)::(DataInt, ArrayIndex),(5,5)::(DataInt, ArrayIndex),(6,5)::(DataInt, ArrayIndex),(6,5)::(DataInt, ArrayIndex)
452             ,(4,6)::(DataInt, ArrayIndex),(4,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(1,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(3,6)::(DataInt, ArrayIndex),(12,6)::(DataInt, ArrayIndex),(12,6)::(DataInt, ArrayIndex)
453             ,(2,6)::(DataInt, ArrayIndex),(14,6)::(DataInt, ArrayIndex),(11,7)::(DataInt, ArrayIndex),(13,7)::(DataInt, ArrayIndex),(17,7)::(DataInt, ArrayIndex),(9,7)::(DataInt, ArrayIndex),(19,8)::(DataInt, ArrayIndex),(4,9)::(DataInt, ArrayIndex),(18,10)::(DataInt, ArrayIndex)
454             ,(6,10)::(DataInt, ArrayIndex),(18,11)::(DataInt, ArrayIndex),(1,12)::(DataInt, ArrayIndex),(3,12)::(DataInt, ArrayIndex),(14,12)::(DataInt, ArrayIndex),(18,12)::(DataInt, ArrayIndex),(14,12)::(DataInt, ArrayIndex),(6,13)::(DataInt, ArrayIndex)
455             ,(9,13)::(DataInt, ArrayIndex),(11,14)::(DataInt, ArrayIndex),(4,14)::(DataInt, ArrayIndex),(1,14)::(DataInt, ArrayIndex),(14,14)::(DataInt, ArrayIndex),(14,14)::(DataInt, ArrayIndex),(6,14)::(DataInt, ArrayIndex),(11,15)::(DataInt, ArrayIndex)
456             ,(13,15)::(DataInt, ArrayIndex),(7,15)::(DataInt, ArrayIndex),(2,16)::(DataInt, ArrayIndex),(16,16)::(DataInt, ArrayIndex),(17,16)::(DataInt, ArrayIndex),(5,16)::(DataInt, ArrayIndex),(20,16)::(DataInt, ArrayIndex),(17,16)::(DataInt, ArrayIndex)
457             ,(14,16)::(DataInt, ArrayIndex),(18,17)::(DataInt, ArrayIndex),(13,17)::(DataInt, ArrayIndex),(1,17)::(DataInt, ArrayIndex),(19,18)::(DataInt, ArrayIndex),(1,18)::(DataInt, ArrayIndex),(20,18)::(DataInt, ArrayIndex),(4,18)::(DataInt, ArrayIndex)
458             ,(5,19)::(DataInt, ArrayIndex),(4,19)::(DataInt, ArrayIndex),(6,19)::(DataInt, ArrayIndex),(19,19)::(DataInt, ArrayIndex),(4,19)::(DataInt, ArrayIndex),(3,19)::(DataInt, ArrayIndex),(7,19)::(DataInt, ArrayIndex),(13,19)::(DataInt, ArrayIndex),(19,19)::(DataInt, ArrayIndex)
459             ,(8,19)::(DataInt, ArrayIndex)
460             ]