Find state arguments / results in top level functions.
[matthijs/master-project/cλash.git] / Translator.hs
index 93eeacd30792761438d90d92e03b78c690bad0ef..2f6a2922d4bad2098cc4aa722af698b00998c1bf 100644 (file)
@@ -45,7 +45,7 @@ main =
           --core <- GHC.compileToCoreSimplified "Adders.hs"
           core <- GHC.compileToCoreSimplified "Adders.hs"
           --liftIO $ printBinds (cm_binds core)
-          let binds = Maybe.mapMaybe (findBind (cm_binds core)) ["full_adder", "half_adder"]
+          let binds = Maybe.mapMaybe (findBind (cm_binds core)) ["shalf_adder"]
           liftIO $ printBinds binds
           -- Turn bind into VHDL
           let vhdl = State.evalState (mkVHDL binds) (VHDLSession 0 [])
@@ -57,16 +57,11 @@ main =
     mkVHDL binds = do
       -- Add the builtin functions
       mapM (uncurry addFunc) builtin_funcs
-      -- Get the function signatures
-      funcs <- mapM mkHWFunction binds
-      -- Add them to the session
-      mapM (uncurry addFunc) funcs
-      let entities = map getEntity (snd $ unzip funcs)
-      -- Create architectures for them
-      archs <- mapM getArchitecture binds
+      -- Create entities and architectures for them
+      units <- mapM expandBind binds
       return $ AST.DesignFile 
         []
-        ((map AST.LUEntity entities) ++ (map AST.LUArch archs))
+        (concat units)
 
 printTarget (Target (TargetFile file (Just x)) obj Nothing) =
   print $ show file
@@ -357,16 +352,34 @@ mapOutputPorts (Single (portname, _)) (Single (signalname, _)) =
 mapOutputPorts (Tuple ports) (Tuple signals) =
   concat (zipWith mapOutputPorts ports signals)
 
+expandBind ::
+  CoreBind                        -- The binder to expand into VHDL
+  -> VHDLState [AST.LibraryUnit]  -- The resulting VHDL
+
+expandBind (Rec _) = error "Recursive binders not supported"
+
+expandBind bind@(NonRec var expr) = do
+  -- Create the function signature
+  hwfunc <- mkHWFunction bind
+  let ty = CoreUtils.exprType expr
+  let hsfunc = mkHsFunction var ty
+  -- Add it to the session
+  addFunc hsfunc hwfunc 
+  arch <- getArchitecture hwfunc expr
+  let entity = getEntity hwfunc
+  return $ [
+    AST.LUEntity entity,
+    AST.LUArch arch ]
+
 getArchitecture ::
-  CoreBind                  -- The binder to expand into an architecture
+  HWFunction                -- The function to generate an architecture for
+  -> CoreExpr               -- The expression that is bound to the function
   -> VHDLState AST.ArchBody -- The resulting architecture
    
-getArchitecture (Rec _) = error "Recursive binders not supported"
-
-getArchitecture (NonRec var expr) = do
-  let name = (getOccString var)
-  HWFunction vhdl_id inports outport <- getHWFunc (HsFunction name [] (Tuple []))
-  sess <- State.get
+getArchitecture hwfunc expr = do
+  -- Unpack our hwfunc
+  let HWFunction vhdl_id inports outport = hwfunc
+  -- Expand the expression into an architecture body
   (signal_decls, statements, arg_signals, res_signal) <- expandExpr [] expr
   let inport_assigns = concat $ zipWith createSignalAssignments arg_signals inports
   let outport_assigns = createSignalAssignments outport res_signal
@@ -481,10 +494,10 @@ data HWFunction = HWFunction { -- A function that is available in hardware
 -- output ports.
 mkHWFunction ::
   CoreBind                                   -- The core binder to generate the interface for
-  -> VHDLState (HsFunction, HWFunction)          -- The name of the function and its interface
+  -> VHDLState HWFunction                    -- The function interface
 
 mkHWFunction (NonRec var expr) =
-    return (hsfunc, HWFunction (mkVHDLId name) inports outport)
+    return $ HWFunction (mkVHDLId name) inports outport
   where
     name = getOccString var
     ty = CoreUtils.exprType expr
@@ -496,7 +509,6 @@ mkHWFunction (NonRec var expr) =
       [port] -> [getPortNameMapForTy "portin" port]
       ps     -> getPortNameMapForTys "portin" 0 ps
     outport = getPortNameMapForTy "portout" res
-    hsfunc = HsFunction name [] (Tuple [])
 
 mkHWFunction (Rec _) =
   error "Recursive binders not supported"
@@ -505,6 +517,7 @@ mkHWFunction (Rec _) =
 -- return value) used?
 data HsValueUse = 
   Port -- ^ Use it as a port (input or output)
+  | State --- ^ Use it as state (input or output)
   deriving (Show, Eq)
 
 -- | This type describes a particular use of a Haskell function and is used to
@@ -531,6 +544,36 @@ appToHsFunction f args ty =
     hsres  = mkHsValueMap mkPort ty
     hsname = getOccString f
 
+-- | Translate a top level function declaration to a HsFunction. i.e., which
+--   interface will be provided by this function. This function essentially
+--   defines the "calling convention" for hardware models.
+mkHsFunction ::
+  Var.Var         -- ^ The function defined
+  -> Type         -- ^ The function type (including arguments!)
+  -> HsFunction   -- ^ The resulting HsFunction
+
+mkHsFunction f ty =
+  HsFunction hsname hsargs hsres
+  where
+    mkPort = mkHsValueMap (\x -> Single Port)
+    mkState = mkHsValueMap (\x -> Single State)
+    hsname  = getOccString f
+    (arg_tys, res_ty) = Type.splitFunTys ty
+    -- The last argument must be state
+    state_ty = last arg_tys
+    state    = mkState state_ty
+    -- All but the last argument are inports
+    inports = map mkPort (init arg_tys)
+    hsargs   = inports ++ [state]
+    hsres    = case splitTupleType res_ty of
+      -- Result type must be a two tuple (state, ports)
+      Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
+        then
+          Tuple [state, mkPort outport_ty]
+        else
+          error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
+      otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
+
 data VHDLSession = VHDLSession {
   nameCount :: Int,                       -- A counter that can be used to generate unique names
   funcs     :: [(HsFunction, HWFunction)] -- All functions available
@@ -553,6 +596,21 @@ getHWFunc hsfunc = do
     (error $ "Function " ++ (hsName hsfunc) ++ "is unknown? This should not happen!")
     (lookup hsfunc fs)
 
+-- | Splits a tuple type into a list of element types, or Nothing if the type
+--   is not a tuple type.
+splitTupleType ::
+  Type              -- ^ The type to split
+  -> Maybe [Type]   -- ^ The tuples element types
+
+splitTupleType ty =
+  case Type.splitTyConApp_maybe ty of
+    Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
+      then
+        Just args
+      else
+        Nothing
+    Nothing -> Nothing
+
 -- Makes the given name unique by appending a unique number.
 -- This does not do any checking against existing names, so it only guarantees
 -- uniqueness with other names generated by uniqueName.