Make listBind support recursive bindings.
[matthijs/master-project/cλash.git] / Translator.hs
index 1ecbdb9223b3c4267649f4705e1ffa000b358de0..1ce9307d72992452d5bfe8415e667d3823c46c59 100644 (file)
@@ -8,6 +8,9 @@ import qualified Var
 import qualified Type
 import qualified TyCon
 import qualified DataCon
+import qualified HscMain
+import qualified SrcLoc
+import qualified FastString
 import qualified Maybe
 import qualified Module
 import qualified Data.Foldable as Foldable
@@ -46,7 +49,7 @@ import VHDLTypes
 import qualified VHDL
 
 main = do
-  makeVHDL "Alu.hs" "register_bank" True
+  makeVHDL "Alu.hs" "exec" True
 
 makeVHDL :: String -> String -> Bool -> IO ()
 makeVHDL filename name stateful = do
@@ -63,11 +66,13 @@ makeVHDL filename name stateful = do
 listBind :: String -> String -> IO ()
 listBind filename name = do
   core <- loadModule filename
-  let binds = findBinds core [name]
+  let [(b, expr)] = findBinds core [name]
   putStr "\n"
-  putStr $ prettyShow binds
+  putStr $ prettyShow expr
   putStr "\n\n"
-  putStr $ showSDoc $ ppr binds
+  putStr $ showSDoc $ ppr expr
+  putStr "\n\n"
+  putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
   putStr "\n\n"
 
 -- | Translate the binds with the given names from the given core module to
@@ -86,7 +91,7 @@ moduleToVHDL core list = do
   return vhdl
   where
     -- Turns the given bind into VHDL
-    mkVHDL :: [CoreBind] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
+    mkVHDL :: [(CoreBndr, CoreExpr)] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
     mkVHDL binds statefuls = do
       -- Add the builtin functions
       --mapM addBuiltIn builtin_funcs
@@ -125,28 +130,24 @@ loadModule filename =
       return core
 
 -- | Extracts the named binds from the given module.
-findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
-findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
+findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
+findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
 
 -- | Extract a named bind from the given list of binds
-findBind :: [CoreBind] -> String -> Maybe CoreBind
+findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
 findBind binds lookfor =
   -- This ignores Recs and compares the name of the bind with lookfor,
   -- disregarding any namespaces in OccName and extra attributes in Name and
   -- Var.
-  find (\b -> case b of 
-    Rec l -> False
-    NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
-  ) binds
+  find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
 
 -- | Processes the given bind as a top level bind.
 processBind ::
   Bool                       -- ^ Should this be stateful function?
-  -> CoreBind                -- ^ The bind to process
+  -> (CoreBndr, CoreExpr)    -- ^ The bind to process
   -> TranslatorState ()
 
-processBind _ (Rec _) = error "Recursive binders not supported"
-processBind stateful bind@(NonRec var expr) = do
+processBind stateful bind@(var, expr) = do
   -- Create the function signature
   let ty = CoreUtils.exprType expr
   let hsfunc = mkHsFunction var ty stateful
@@ -157,18 +158,16 @@ processBind stateful bind@(NonRec var expr) = do
 --   with them.
 flattenBind ::
   HsFunction                         -- The signature to flatten into
-  -> CoreBind                        -- The bind to flatten
+  -> (CoreBndr, CoreExpr)            -- The bind to flatten
   -> TranslatorState ()
 
-flattenBind _ (Rec _) = error "Recursive binders not supported"
-
-flattenBind hsfunc bind@(NonRec var expr) = do
+flattenBind hsfunc bind@(var, expr) = do
   -- Flatten the function
   let flatfunc = flattenFunction hsfunc bind
   -- Propagate state variables
   let flatfunc' = propagateState hsfunc flatfunc
   -- Store the flat function in the session
-  modA tsFlatFuncs (Map.insert hsfunc flatfunc)
+  modA tsFlatFuncs (Map.insert hsfunc flatfunc')
   -- Flatten any functions used
   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
   mapM_ resolvFunc used_hsfuncs
@@ -268,14 +267,15 @@ resolvFunc ::
 
 resolvFunc hsfunc = do
   flatfuncmap <- getA tsFlatFuncs
-  -- Don't do anything if there is already a flat function for this hsfunc.
+  -- Don't do anything if there is already a flat function for this hsfunc or
+  -- when it is a builtin function.
   Monad.unless (Map.member hsfunc flatfuncmap) $ do
-  -- TODO: Builtin functions
+  Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
   -- New function, resolve it
   core <- getA tsCoreModule
   -- Find the named function
   let name = (hsFuncName hsfunc)
-  let bind = findBind (cm_binds core) name 
+  let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
   case bind of
     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
     Just b  -> flattenBind hsfunc b
@@ -352,33 +352,4 @@ splitTupleType ty =
         Nothing
     Nothing -> Nothing
 
--- | A consise representation of a (set of) ports on a builtin function
-type PortMap = HsValueMap (String, AST.TypeMark)
--- | A consise representation of a builtin function
-data BuiltIn = BuiltIn String [PortMap] PortMap
-
--- | Map a port specification of a builtin function to a VHDL Signal to put in
---   a VHDLSignalMap
-toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
-toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
-
--- | Translate a concise representation of a builtin function to something
---   that can be put into FuncMap directly.
-{-
-addBuiltIn :: BuiltIn -> TranslatorState ()
-addBuiltIn (BuiltIn name args res) = do
-    addFunc hsfunc
-    setEntity hsfunc entity
-  where
-    hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
-    entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing Nothing
-
-builtin_funcs = 
-  [ 
-    BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
-    BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
-    BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
-    BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
-  ]
--}
 -- vim: set ts=8 sw=2 sts=2 expandtab: