Connect resetn port to states.
[matthijs/master-project/cλash.git] / reducer.hs
index 65730f4ec084f652c029f1e742d6d5a368e57e9e..ecc0051c3a896dfb9407a3a0346ff4d720e71d81 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE TypeOperators, TemplateHaskell #-}
+{-# LANGUAGE TypeOperators, TemplateHaskell, FlexibleContexts, TypeFamilies #-}
 module Reducer where
 
 import System.Random
@@ -18,8 +18,6 @@ type DataInt        = SizedWord DataSize
 type ArrayIndex     = SizedWord IndexSize
 type Discr          = RangedWord DiscrRange
 
-type RAM a          = Vector (DiscrRange :+: D1) a
-
 type ReducerState   = State ( DiscrState
                       , InputState
                       , FpAdderState
@@ -30,8 +28,6 @@ type ReducerSignal  = ( ( DataInt
                         )
                       , Bit
                       )
-
-type MemState a      = State (RAM a)
                       
 type OutputSignal   = ( (DataInt
                         , ArrayIndex
@@ -49,10 +45,10 @@ type InputState     = State ( Vector (AdderDepth :+: D1) ReducerSignal
 
 type FpAdderState   = State (Vector AdderDepth ReducerSignal)
 
-type OutputState    = State ( MemState DataInt
-                            , MemState DataInt
-                            , RAM ArrayIndex
-                            , RAM Bit
+type OutputState    = State ( MemState DiscrRange DataInt
+                            , MemState DiscrRange DataInt
+                            , MemState DiscrRange ArrayIndex
+                            , RAM DiscrRange Bit
                       )
 {-
 Discriminator adds a discriminator to each input value
@@ -109,8 +105,8 @@ valid indicates if the output contains a valid discriminator
 -}
 inputBuffer ::  InputState -> 
                 ((DataInt, Discr), RangedWord D2) -> 
-                (InputState, (ReducerSignal, ReducerSignal))            
-inputBuffer (State (mem,wrptr)) (inp,enable) = (State (mem',wrptr'),(out1, out2))
+                (InputState, ReducerSignal, ReducerSignal)
+inputBuffer (State (mem,wrptr)) (inp,enable) = (State (mem',wrptr'),out1, out2)
   where
     out1                  = last mem -- output head of FIFO
     out2                  = last (init mem) -- output 2nd element
@@ -177,39 +173,6 @@ fpAdder (State state) (input1, input2, grant, mem_out) = (State state', output)
               | otherwise         = High
     -- Shift addition of the two operants into the pipeline
     state'    = (((operant1 + operant2),discr),valid) +> (init state)
-    
-
-{- 
-first attempt at BlockRAM
-
-State:
-mem: content of the RAM
-
-Input:
-data_in: input value to be written to 'mem' at location 'wraddr'
-rdaddr: read address
-wraddr: write address
-wrenable: write enable flag
-
-Output:
-data_out: value of 'mem' at location 'rdaddr'
--}
-{-# NOINLINE blockRAM #-}
-blockRAM :: (MemState a) -> 
-            ( a
-            , Discr
-            , Discr
-            , Bit
-            ) -> 
-            ((MemState a), a )
-blockRAM (State mem) (data_in, rdaddr, wraddr, wrenable) = 
-  ((State mem'), data_out)
-  where
-    data_out  = mem!rdaddr
-    -- Only write data_in to memory if write is enabled
-    mem' = case wrenable of
-      Low   ->  mem
-      High  ->  replace mem wraddr data_in
 
 {-
 Output logic - Determines when values are released from blockram to the output
@@ -241,14 +204,13 @@ outputter ::  OutputState ->
               , Discr
               , Bit
               ) -> 
-              (OutputState, (ReducerSignal, OutputSignal))                 
+              (OutputState, ReducerSignal, OutputSignal)
 outputter (State (mem1, mem2, lut, valid))
   (discr, index, new_discr, data_in, rdaddr, wraddr, wrenable) = 
-  ((State (mem1', mem2', lut', valid')), (data_out, output))
+  ((State (mem1', mem2', lut', valid')), data_out, output)
   where
     -- Lut is updated when new discriminator/index combination enters system        
-    lut'    | new_discr /= Low  = replace lut discr index
-            | otherwise         = lut
+    (lut', lut_out)             = blockRAM lut index discr discr new_discr
     -- Location becomes invalid when Reduced row leaves system        
     valid'' | (new_discr /= Low) && ((valid!discr) /= Low) = 
                                   replace valid discr Low
@@ -256,16 +218,8 @@ outputter (State (mem1, mem2, lut, valid))
     -- Location becomes invalid when it is fed back into the pipeline
     valid'  | wrenable == Low   = replace valid'' rdaddr Low
             | otherwise         = replace valid'' wraddr High
-    (mem1', mem_out1)           = blockRAM mem1 ( data_in
-                                                , rdaddr
-                                                , wraddr
-                                                , wrenable
-                                                )
-    (mem2', mem_out2)           = blockRAM mem2 ( data_in
-                                            , discr
-                                            , wraddr
-                                            , wrenable
-                                            )
+    (mem1', mem_out1)           = blockRAM mem1 data_in rdaddr wraddr wrenable
+    (mem2', mem_out2)           = blockRAM mem2 data_in discr wraddr wrenable
     data_out                    = ( ( (mem_out1)
                                     , rdaddr
                                     )
@@ -274,7 +228,7 @@ outputter (State (mem1, mem2, lut, valid))
     -- Reduced row is released when new discriminator enters system
     -- And the position at the discriminator holds a valid value
     output                      = ( ( (mem_out2)
-                                    , (lut!discr)
+                                    , (lut_out)
                                     )
                                   , (new_discr `hwand` (valid!discr))
                                   )
@@ -391,7 +345,7 @@ reducer (State (discrstate,inputstate,fpadderstate,outputstate)) input =
   (State (discrstate',inputstate',fpadderstate',outputstate'),output)
   where
     (discrstate', discr_out)              = discriminator discrstate input
-    (inputstate',(fifo_out1, fifo_out2))  = inputBuffer inputstate (
+    (inputstate',fifo_out1, fifo_out2)    = inputBuffer inputstate (
                                             (fst discr_out), enable)
     (fpadderstate', fp_out)               = fpAdder fpadderstate (fifo_out1, 
                                                 fifo_out2, grant, mem_out)
@@ -401,7 +355,7 @@ reducer (State (discrstate,inputstate,fpadderstate,outputstate)) input =
     rdaddr                                = snd (fst fp_out)
     wraddr                                = rdaddr
     data_in                               = fst (fst fp_out)
-    (outputstate', (mem_out, output))     = outputter outputstate (discr, 
+    (outputstate', mem_out, output)       = outputter outputstate (discr,
                                             index, new_discr, data_in, rdaddr, 
                                             wraddr, wr_enable)
     (grant,enable,wr_enable)              = controller (fp_out, mem_out, 
@@ -484,22 +438,23 @@ initstate = State
   , State (copy ((0::DataInt,0::Discr),Low))
   , State ( State (copy (0::DataInt))
           , State (copy (0::DataInt))
-          , (copy (0::ArrayIndex))
+          , State (copy (0::ArrayIndex))
           , (copy Low)
           )
   )
 
+{-# ANN siminput TestInput #-}
 siminput :: [(DataInt, ArrayIndex)]
-siminput =  [(13,0),(7,0),(14,0),(14,0),(12,0),(10,0),(19,1),(20,1),(13,1)
-            ,(5,1),(9,1),(16,1),(15,1),(10,2),(13,2),(3,2),(9,2),(19,2),(5,3)
-            ,(5,3),(10,3),(17,3),(14,3),(5,3),(15,3),(11,3),(5,3),(1,3),(8,4)
-            ,(20,4),(8,4),(1,4),(11,4),(10,4),(13,5),(18,5),(5,5),(6,5),(6,5)
-            ,(4,6),(4,6),(11,6),(11,6),(11,6),(1,6),(11,6),(3,6),(12,6),(12,6)
-            ,(2,6),(14,6),(11,7),(13,7),(17,7),(9,7),(19,8),(4,9),(18,10)
-            ,(6,10),(18,11),(1,12),(3,12),(14,12),(18,12),(14,12),(6,13)
-            ,(9,13),(11,14),(4,14),(1,14),(14,14),(14,14),(6,14),(11,15)
-            ,(13,15),(7,15),(2,16),(16,16),(17,16),(5,16),(20,16),(17,16)
-            ,(14,16),(18,17),(13,17),(1,17),(19,18),(1,18),(20,18),(4,18)
-            ,(5,19),(4,19),(6,19),(19,19),(4,19),(3,19),(7,19),(13,19),(19,19)
-            ,(8,19)
+siminput =  [(13,0)::(DataInt, ArrayIndex),(7,0)::(DataInt, ArrayIndex),(14,0)::(DataInt, ArrayIndex),(14,0)::(DataInt, ArrayIndex),(12,0)::(DataInt, ArrayIndex),(10,0)::(DataInt, ArrayIndex),(19,1)::(DataInt, ArrayIndex),(20,1)::(DataInt, ArrayIndex),(13,1)::(DataInt, ArrayIndex)
+            ,(5,1)::(DataInt, ArrayIndex),(9,1)::(DataInt, ArrayIndex),(16,1)::(DataInt, ArrayIndex),(15,1)::(DataInt, ArrayIndex),(10,2)::(DataInt, ArrayIndex),(13,2)::(DataInt, ArrayIndex),(3,2)::(DataInt, ArrayIndex),(9,2)::(DataInt, ArrayIndex),(19,2)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex)
+            ,(5,3)::(DataInt, ArrayIndex),(10,3)::(DataInt, ArrayIndex),(17,3)::(DataInt, ArrayIndex),(14,3)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex),(15,3)::(DataInt, ArrayIndex),(11,3)::(DataInt, ArrayIndex),(5,3)::(DataInt, ArrayIndex),(1,3)::(DataInt, ArrayIndex),(8,4)::(DataInt, ArrayIndex)
+            ,(20,4)::(DataInt, ArrayIndex),(8,4)::(DataInt, ArrayIndex),(1,4)::(DataInt, ArrayIndex),(11,4)::(DataInt, ArrayIndex),(10,4)::(DataInt, ArrayIndex),(13,5)::(DataInt, ArrayIndex),(18,5)::(DataInt, ArrayIndex),(5,5)::(DataInt, ArrayIndex),(6,5)::(DataInt, ArrayIndex),(6,5)::(DataInt, ArrayIndex)
+            ,(4,6)::(DataInt, ArrayIndex),(4,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(1,6)::(DataInt, ArrayIndex),(11,6)::(DataInt, ArrayIndex),(3,6)::(DataInt, ArrayIndex),(12,6)::(DataInt, ArrayIndex),(12,6)::(DataInt, ArrayIndex)
+            ,(2,6)::(DataInt, ArrayIndex),(14,6)::(DataInt, ArrayIndex),(11,7)::(DataInt, ArrayIndex),(13,7)::(DataInt, ArrayIndex),(17,7)::(DataInt, ArrayIndex),(9,7)::(DataInt, ArrayIndex),(19,8)::(DataInt, ArrayIndex),(4,9)::(DataInt, ArrayIndex),(18,10)::(DataInt, ArrayIndex)
+            ,(6,10)::(DataInt, ArrayIndex),(18,11)::(DataInt, ArrayIndex),(1,12)::(DataInt, ArrayIndex),(3,12)::(DataInt, ArrayIndex),(14,12)::(DataInt, ArrayIndex),(18,12)::(DataInt, ArrayIndex),(14,12)::(DataInt, ArrayIndex),(6,13)::(DataInt, ArrayIndex)
+            ,(9,13)::(DataInt, ArrayIndex),(11,14)::(DataInt, ArrayIndex),(4,14)::(DataInt, ArrayIndex),(1,14)::(DataInt, ArrayIndex),(14,14)::(DataInt, ArrayIndex),(14,14)::(DataInt, ArrayIndex),(6,14)::(DataInt, ArrayIndex),(11,15)::(DataInt, ArrayIndex)
+            ,(13,15)::(DataInt, ArrayIndex),(7,15)::(DataInt, ArrayIndex),(2,16)::(DataInt, ArrayIndex),(16,16)::(DataInt, ArrayIndex),(17,16)::(DataInt, ArrayIndex),(5,16)::(DataInt, ArrayIndex),(20,16)::(DataInt, ArrayIndex),(17,16)::(DataInt, ArrayIndex)
+            ,(14,16)::(DataInt, ArrayIndex),(18,17)::(DataInt, ArrayIndex),(13,17)::(DataInt, ArrayIndex),(1,17)::(DataInt, ArrayIndex),(19,18)::(DataInt, ArrayIndex),(1,18)::(DataInt, ArrayIndex),(20,18)::(DataInt, ArrayIndex),(4,18)::(DataInt, ArrayIndex)
+            ,(5,19)::(DataInt, ArrayIndex),(4,19)::(DataInt, ArrayIndex),(6,19)::(DataInt, ArrayIndex),(19,19)::(DataInt, ArrayIndex),(4,19)::(DataInt, ArrayIndex),(3,19)::(DataInt, ArrayIndex),(7,19)::(DataInt, ArrayIndex),(13,19)::(DataInt, ArrayIndex),(19,19)::(DataInt, ArrayIndex)
+            ,(8,19)::(DataInt, ArrayIndex)
             ]