Create a VHDL proc for each state variable.
[matthijs/master-project/cλash.git] / Translator.hs
index b2ce3ef59d679f95afb9b8aa8010af708b4bd21b..66e2cb895aaafc62744e32b7838d7279ed5cd15c 100644 (file)
@@ -45,7 +45,7 @@ main =
           --core <- GHC.compileToCoreSimplified "Adders.hs"
           core <- GHC.compileToCoreSimplified "Adders.hs"
           --liftIO $ printBinds (cm_binds core)
-          let binds = Maybe.mapMaybe (findBind (cm_binds core)) ["shalf_adder"]
+          let binds = Maybe.mapMaybe (findBind (cm_binds core)) ["dff"]
           liftIO $ printBinds binds
           -- Turn bind into VHDL
           let (vhdl, sess) = State.runState (mkVHDL binds) (VHDLSession 0 [])
@@ -124,7 +124,8 @@ expandExpr binds lam@(Lam b expr) = do
   -- Find the type of the binder
   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
   -- Create signal names for the binder
-  let arg_signal = getPortNameMapForTy ("xxx") arg_ty
+  -- TODO: We assume arguments are ports here
+  let arg_signal = getPortNameMapForTy signal_name arg_ty (useAsPort arg_ty)
   -- Create the corresponding signal declarations
   let signal_decls = mkSignalsFromMap arg_signal
   -- Add the binder to the list of binds
@@ -256,7 +257,8 @@ expandApplicationExpr binds ty f args = do
   -- Bind each of the input ports to the expanded arguments
   let inmaps = concat $ zipWith createAssocElems inports arg_res_signals
   -- Create signal names for our result
-  let res_signal = getPortNameMapForTy (appname ++ "_out") ty
+  -- TODO: We assume the result is a port here
+  let res_signal = getPortNameMapForTy (appname ++ "_out") ty (useAsPort ty)
   -- Create the corresponding signal declarations
   let signal_decls = mkSignalsFromMap res_signal
   -- Bind each of the output ports to our output signals
@@ -361,43 +363,79 @@ expandBind (Rec _) = error "Recursive binders not supported"
 
 expandBind bind@(NonRec var expr) = do
   -- Create the function signature
-  hwfunc <- mkHWFunction bind
   let ty = CoreUtils.exprType expr
   let hsfunc = mkHsFunction var ty
+  hwfunc <- mkHWFunction bind hsfunc
   -- Add it to the session
   addFunc hsfunc hwfunc 
-  arch <- getArchitecture hwfunc expr
-  let entity = getEntity hwfunc
+  arch <- getArchitecture hsfunc hwfunc expr
+  -- Give every entity a clock port
+  -- TODO: Omit this for stateless entities
+  let clk_port = AST.IfaceSigDec (mkVHDLId "clk") AST.In vhdl_bit_ty
+  let entity = getEntity hwfunc [clk_port]
   return $ [
     AST.LUEntity entity,
     AST.LUArch arch ]
 
 getArchitecture ::
-  HWFunction                -- The function to generate an architecture for
+  HsFunction                -- The function interface
+  -> HWFunction             -- The function to generate an architecture for
   -> CoreExpr               -- The expression that is bound to the function
   -> VHDLState AST.ArchBody -- The resulting architecture
    
-getArchitecture hwfunc expr = do
+getArchitecture hsfunc hwfunc expr = do
   -- Unpack our hwfunc
   let HWFunction vhdl_id inports outport = hwfunc
   -- Expand the expression into an architecture body
   (signal_decls, statements, arg_signals, res_signal) <- expandExpr [] expr
-  let inport_assigns = concat $ zipWith createSignalAssignments arg_signals inports
-  let outport_assigns = createSignalAssignments outport res_signal
+  let (inport_assigns, instate_map)  = concat_elements $ unzip $ zipWith3 createSignalAssignments arg_signals inports (hsArgs hsfunc)
+  let (outport_assigns, outstate_map) = createSignalAssignments outport res_signal (hsRes hsfunc)
+  let state_procs = map AST.CSPSm $ createStateProcs (sortMap instate_map) (sortMap outstate_map)
   return $ AST.ArchBody
     (AST.unsafeVHDLBasicId "structural")
     (AST.NSimple vhdl_id)
     (map AST.BDISD signal_decls)
-    (inport_assigns ++ outport_assigns ++ statements)
+    (state_procs ++ inport_assigns ++ outport_assigns ++ statements)
+
+-- | Sorts a map modeled as a list of (key,value) pairs by key
+sortMap :: Ord a => [(a, b)] -> [(a, b)]
+sortMap = List.sortBy (\(a, _) (b, _) -> compare a b)
+
+-- | Generate procs for state variables
+createStateProcs ::
+  [(Int, AST.VHDLId)]
+                    -- ^ The sorted list of signals that should be assigned
+                    --   to each state
+  -> [(Int, AST.VHDLId)]   
+                    -- ^ The sorted list of signals that contain each new state
+  -> [AST.ProcSm]   -- ^ The resulting procs
+
+createStateProcs ((old_num, old_id):olds) ((new_num, new_id):news) =
+  if (old_num == new_num)
+    then
+      AST.ProcSm label [clk] [statement] : createStateProcs olds news
+    else
+      error "State numbers don't match!"
+  where
+    label       = mkVHDLId $ "state_" ++ (show old_num)
+    clk         = mkVHDLId "clk"
+    rising_edge = AST.NSimple $ mkVHDLId "rising_edge"
+    wform       = AST.Wform [AST.WformElem (AST.PrimName $ AST.NSimple $ new_id) Nothing]
+    assign      = AST.SigAssign (AST.NSimple old_id) wform
+    rising_edge_clk = AST.PrimFCall $ AST.FCall rising_edge [Nothing AST.:=>: (AST.ADName $ AST.NSimple clk)]
+    statement   = AST.IfSm rising_edge_clk [assign] [] Nothing
+
+createStateProcs [] [] = []
 
 -- Generate a VHDL entity declaration for the given function
-getEntity :: HWFunction -> AST.EntityDec  
-getEntity (HWFunction vhdl_id inports outport) = 
+getEntity :: HWFunction -> [AST.IfaceSigDec] -> AST.EntityDec  
+getEntity (HWFunction vhdl_id inports outport) extra_ports 
   AST.EntityDec vhdl_id ports
   where
     ports = 
       (concat $ map (mkIfaceSigDecs AST.In) inports)
       ++ mkIfaceSigDecs AST.Out outport
+      ++ extra_ports
 
 mkIfaceSigDecs ::
   AST.Mode                        -- The port's mode (In or Out)
@@ -410,17 +448,23 @@ mkIfaceSigDecs mode (Single (port_id, ty)) =
 mkIfaceSigDecs mode (Tuple ports) =
   concat $ map (mkIfaceSigDecs mode) ports
 
+-- Unused values (state) don't generate ports
+mkIfaceSigDecs mode Unused =
+  []
+
 -- Create concurrent assignments of one map of signals to another. The maps
 -- should have a similar form.
 createSignalAssignments ::
-  SignalNameMap         -- The signals to assign to
-  -> SignalNameMap      -- The signals to assign
-  -> [AST.ConcSm]                  -- The resulting assignments
+  SignalNameMap           -- The signals to assign to
+  -> SignalNameMap        -- The signals to assign
+  -> HsUseMap             -- What function does each of the signals have?
+  -> ([AST.ConcSm],       -- The resulting assignments
+      [(Int, AST.VHDLId)]) -- The resulting state -> signal mappings
 
 -- A simple assignment of one signal to another (greatly complicated because
 -- signal assignments can be conditional with multiple conditions in VHDL).
-createSignalAssignments (Single (dst, _)) (Single (src, _)) =
-    [AST.CSSASm assign]
+createSignalAssignments (Single (dst, _)) (Single (src, _)) (Single Port)=
+    ([AST.CSSASm assign], [])
   where
     src_name  = AST.NSimple src
     src_expr  = AST.PrimName src_name
@@ -428,11 +472,19 @@ createSignalAssignments (Single (dst, _)) (Single (src, _)) =
     dst_name  = (AST.NSimple dst)
     assign    = dst_name AST.:<==: (AST.ConWforms [] src_wform Nothing)
 
-createSignalAssignments (Tuple dsts) (Tuple srcs) =
-  concat $ zipWith createSignalAssignments dsts srcs
+createSignalAssignments (Tuple dsts) (Tuple srcs) (Tuple uses) =
+  concat_elements $ unzip $ zipWith3 createSignalAssignments dsts srcs uses
+
+createSignalAssignments Unused (Single (src, _)) (Single (State n)) =
+  -- Write state
+  ([], [(n, src)])
 
-createSignalAssignments dst src =
-  error $ "Non matching source and destination: " ++ show dst ++ "\nand\n" ++  show src
+createSignalAssignments (Single (dst, _)) Unused (Single (State n)) =
+  -- Read state
+  ([], [(n, dst)])
+
+createSignalAssignments dst src use =
+  error $ "Non matching source and destination: " ++ show dst ++ " <= " ++  show src ++ " (Used as " ++ show use ++ ")"
 
 type SignalNameMap = HsValueMap (AST.VHDLId, AST.TypeMark)
 
@@ -442,43 +494,59 @@ type SignalNameMap = HsValueMap (AST.VHDLId, AST.TypeMark)
 data HsValueMap mapto =
   Tuple [HsValueMap mapto]
   | Single mapto
+  | Unused
   deriving (Show, Eq)
 
 -- | Creates a HsValueMap with the same structure as the given type, using the
 --   given function for mapping the single types.
 mkHsValueMap ::
-  (Type -> HsValueMap mapto)    -- ^ A function to map single value Types
+  ((Type, s) -> (HsValueMap mapto, s))
+                                -- ^ A function to map single value Types
                                 --   (basically anything but tuples) to a
                                 --   HsValueMap (not limited to the Single
-                                --   constructor)
+                                --   constructor) Also accepts and produces a
+                                --   state that will be passed on between
+                                --   each call to the function.
+  -> s                          -- ^ The initial state
   -> Type                       -- ^ The type to map to a HsValueMap
-  -> HsValueMap mapto           -- ^ The resulting map
+  -> (HsValueMap mapto, s)      -- ^ The resulting map and state
 
-mkHsValueMap f ty =
+mkHsValueMap f ty =
   case Type.splitTyConApp_maybe ty of
     Just (tycon, args) ->
       if (TyCon.isTupleTyCon tycon) 
         then
+          let (args', s') = mapTuple f s args in
           -- Handle tuple construction especially
-          Tuple (map (mkHsValueMap f) args)
+          (Tuple args', s')
         else
           -- And let f handle the rest
-          f ty
+          f (ty, s)
     -- And let f handle the rest
-    Nothing -> f ty
+    Nothing -> f (ty, s)
+  where
+    mapTuple f s (ty:tys) =
+      let (map, s') = mkHsValueMap f s ty in
+      let (maps, s'') = mapTuple f s' tys in
+      (map: maps, s'')
+    mapTuple f s [] = ([], s)
 
 -- Generate a port name map (or multiple for tuple types) in the given direction for
 -- each type given.
-getPortNameMapForTys :: String -> Int -> [Type] -> [SignalNameMap]
-getPortNameMapForTys prefix num [] = [] 
-getPortNameMapForTys prefix num (t:ts) =
-  (getPortNameMapForTy (prefix ++ show num) t) : getPortNameMapForTys prefix (num + 1) ts
+getPortNameMapForTys :: String -> Int -> [Type] -> [HsUseMap] -> [SignalNameMap]
+getPortNameMapForTys prefix num [] [] = [] 
+getPortNameMapForTys prefix num (t:ts) (u:us) =
+  (getPortNameMapForTy (prefix ++ show num) t u) : getPortNameMapForTys prefix (num + 1) ts us
+
+getPortNameMapForTy :: String -> Type -> HsUseMap -> SignalNameMap
+getPortNameMapForTy name _ (Single (State _)) =
+  Unused
 
-getPortNameMapForTy :: String -> Type -> SignalNameMap
-getPortNameMapForTy name ty =
+getPortNameMapForTy name ty use =
   if (TyCon.isTupleTyCon tycon) then
+    let (Tuple uses) = use in
     -- Expand tuples we find
-    Tuple (getPortNameMapForTys name 0 args)
+    Tuple (getPortNameMapForTys name 0 args uses)
   else -- Assume it's a type constructor application, ie simple data type
     Single ((AST.unsafeVHDLBasicId name), (vhdl_ty ty))
   where
@@ -495,41 +563,45 @@ data HWFunction = HWFunction { -- A function that is available in hardware
 -- output ports.
 mkHWFunction ::
   CoreBind                                   -- The core binder to generate the interface for
+  -> HsFunction                              -- The HsFunction describing the function
   -> VHDLState HWFunction                    -- The function interface
 
-mkHWFunction (NonRec var expr) =
+mkHWFunction (NonRec var expr) hsfunc =
     return $ HWFunction (mkVHDLId name) inports outport
   where
     name = getOccString var
     ty = CoreUtils.exprType expr
-    (fargs, res) = Type.splitFunTys ty
-    args = if length fargs == 1 then fargs else (init fargs)
-    --state = if length fargs == 1 then () else (last fargs)
+    (args, res) = Type.splitFunTys ty
     inports = case args of
       -- Handle a single port specially, to prevent an extra 0 in the name
-      [port] -> [getPortNameMapForTy "portin" port]
-      ps     -> getPortNameMapForTys "portin" 0 ps
-    outport = getPortNameMapForTy "portout" res
+      [port] -> [getPortNameMapForTy "portin" port (head $ hsArgs hsfunc)]
+      ps     -> getPortNameMapForTys "portin" 0 ps (hsArgs hsfunc)
+    outport = getPortNameMapForTy "portout" res (hsRes hsfunc)
 
-mkHWFunction (Rec _) =
+mkHWFunction (Rec _) =
   error "Recursive binders not supported"
 
 -- | How is a given (single) value in a function's type (ie, argument or
 -- return value) used?
 data HsValueUse = 
-  Port -- ^ Use it as a port (input or output)
-  | State --- ^ Use it as state (input or output)
+  Port        -- ^ Use it as a port (input or output)
+  | State Int -- ^ Use it as state (input or output). The int is used to
+              --   match input state to output state.
   deriving (Show, Eq)
 
-useAsPort = mkHsValueMap (\x -> Single Port)
-useAsState = mkHsValueMap (\x -> Single State)
+useAsPort :: Type -> HsUseMap
+useAsPort = fst . (mkHsValueMap (\(ty, s) -> (Single Port, s)) 0)
+useAsState :: Type -> HsUseMap
+useAsState = fst . (mkHsValueMap (\(ty, s) -> (Single $ State s, s + 1)) 0)
+
+type HsUseMap = HsValueMap HsValueUse
 
 -- | This type describes a particular use of a Haskell function and is used to
 --   look up an appropriate hardware description.  
 data HsFunction = HsFunction {
   hsName :: String,                      -- ^ What was the name of the original Haskell function?
-  hsArgs :: [HsValueMap HsValueUse],     -- ^ How are the arguments used?
-  hsRes  :: HsValueMap HsValueUse        -- ^ How is the result value used?
+  hsArgs :: [HsUseMap],                  -- ^ How are the arguments used?
+  hsRes  :: HsUseMap                     -- ^ How is the result value used?
 } deriving (Show, Eq)
 
 -- | Translate a function application to a HsFunction. i.e., which function
@@ -543,9 +615,8 @@ appToHsFunction ::
 appToHsFunction f args ty =
   HsFunction hsname hsargs hsres
   where
-    mkPort = \x -> Single Port
-    hsargs = map (mkHsValueMap mkPort . CoreUtils.exprType) args
-    hsres  = mkHsValueMap mkPort ty
+    hsargs = map (useAsPort . CoreUtils.exprType) args
+    hsres  = useAsPort ty
     hsname = getOccString f
 
 -- | Translate a top level function declaration to a HsFunction. i.e., which
@@ -626,6 +697,12 @@ uniqueName name = do
 mkVHDLId :: String -> AST.VHDLId
 mkVHDLId = AST.unsafeVHDLBasicId
 
+-- Concatenate each of the lists of lists inside the given tuple.
+-- Since the element types in the lists might differ, we can't generalize
+-- this (unless we pass in f twice).
+concat_elements :: ([[a]], [[b]]) -> ([a], [b])
+concat_elements (a, b) = (concat a, concat b)
+
 builtin_funcs = 
   [ 
     (HsFunction "hwxor" [(Single Port), (Single Port)] (Single Port), HWFunction (mkVHDLId "hwxor") [Single (mkVHDLId "a", vhdl_bit_ty), Single (mkVHDLId "b", vhdl_bit_ty)] (Single (mkVHDLId "o", vhdl_bit_ty))),