Let mkHType also return errors using Either.
[matthijs/master-project/cλash.git] / Generate.hs
index 55f015608de743e03540a18e52633d66dfbead5f..b3045def5bfcee16e869fac00e08e603e7335e15 100644 (file)
@@ -6,6 +6,7 @@ import qualified Data.Map as Map
 import qualified Maybe
 import qualified Data.Either as Either
 import Data.Accessor
+import Data.Accessor.MonadState as MonadState
 import Debug.Trace
 
 -- ForSyDe
@@ -71,16 +72,16 @@ genOperator1' op _ f [arg] = return $ op arg
 
 -- | Generate a function call from the destination binder, function name and a
 -- list of expressions (its arguments)
-genFCall :: BuiltinBuilder 
-genFCall = genExprArgs $ genExprRes genFCall'
-genFCall' :: Either CoreSyn.CoreBndr AST.VHDLName -> CoreSyn.CoreBndr -> [AST.Expr] -> VHDLSession AST.Expr
-genFCall' (Left res) f args = do
+genFCall :: Bool -> BuiltinBuilder 
+genFCall switch = genExprArgs $ genExprRes (genFCall' switch)
+genFCall' :: Bool -> Either CoreSyn.CoreBndr AST.VHDLName -> CoreSyn.CoreBndr -> [AST.Expr] -> VHDLSession AST.Expr
+genFCall' switch (Left res) f args = do
   let fname = varToString f
-  let el_ty = (tfvec_elem . Var.varType) res
-  id <- vectorFunId el_ty fname
+  let el_ty = if switch then (Var.varType res) else ((tfvec_elem . Var.varType) res)
+  id <- MonadState.lift vsType $ vectorFunId el_ty fname
   return $ AST.PrimFCall $ AST.FCall (AST.NSimple id)  $
              map (\exp -> Nothing AST.:=>: AST.ADExpr exp) args
-genFCall' (Right name) _ _ = error $ "\nGenerate.genFCall': Cannot generate builtin function call assigned to a VHDLName: " ++ show name
+genFCall' (Right name) _ _ = error $ "\nGenerate.genFCall': Cannot generate builtin function call assigned to a VHDLName: " ++ show name
 
 -- | Generate a generate statement for the builtin function "map"
 genMap :: BuiltinBuilder
@@ -155,7 +156,7 @@ genFold' left (Left res) f [folded_f, start, vec] = do
   -- temporary vector
   let tmp_ty = Type.mkAppTy nvec (Var.varType start)
   let error_msg = "\nGenerate.genFold': Can not construct temp vector for element type: " ++ pprString tmp_ty 
-  tmp_vhdl_ty <- vhdl_ty error_msg tmp_ty
+  tmp_vhdl_ty <- MonadState.lift vsType $ vhdl_ty error_msg tmp_ty
   -- Setup the generate scheme
   let gen_label = mkVHDLExtId ("foldlVector" ++ (varToString vec))
   let block_label = mkVHDLExtId ("foldlVector" ++ (varToString start))
@@ -245,7 +246,7 @@ genZip' (Left res) f args@[arg1, arg2] =
     argexpr1        = vhdlNameToVHDLExpr $ mkIndexedName (varToVHDLName arg1) n_expr
     argexpr2        = vhdlNameToVHDLExpr $ mkIndexedName (varToVHDLName arg2) n_expr
   in do
-    labels <- getFieldLabels (tfvec_elem (Var.varType res))
+    labels <- MonadState.lift vsType $ getFieldLabels (tfvec_elem (Var.varType res))
     let resnameA    = mkSelectedName resname' (labels!!0)
     let resnameB    = mkSelectedName resname' (labels!!1)
     let resA_assign = mkUncondAssign (Right resnameA) argexpr1
@@ -270,8 +271,8 @@ genUnzip' (Left res) f args@[arg] =
     resname'        = varToVHDLName res
     argexpr'        = mkIndexedName (varToVHDLName arg) n_expr
   in do
-    reslabels <- getFieldLabels (Var.varType res)
-    arglabels <- getFieldLabels (tfvec_elem (Var.varType arg))
+    reslabels <- MonadState.lift vsType $ getFieldLabels (Var.varType res)
+    arglabels <- MonadState.lift vsType $ getFieldLabels (tfvec_elem (Var.varType arg))
     let resnameA    = mkIndexedName (mkSelectedName resname' (reslabels!!0)) n_expr
     let resnameB    = mkIndexedName (mkSelectedName resname' (reslabels!!1)) n_expr
     let argexprA    = vhdlNameToVHDLExpr $ mkSelectedName argexpr' (arglabels!!0)
@@ -346,7 +347,7 @@ genIterateOrGenerate' iter (Left res) f [app_f, start] = do
   -- -- temporary vector
   let tmp_ty = Var.varType res
   let error_msg = "\nGenerate.genFold': Can not construct temp vector for element type: " ++ pprString tmp_ty 
-  tmp_vhdl_ty <- vhdl_ty error_msg tmp_ty
+  tmp_vhdl_ty <- MonadState.lift vsType $ vhdl_ty error_msg tmp_ty
   -- Setup the generate scheme
   let gen_label = mkVHDLExtId ("iterateVector" ++ (varToString start))
   let block_label = mkVHDLExtId ("iterateVector" ++ (varToString res))
@@ -420,7 +421,7 @@ genApplication dst f args =
       -- It's a datacon. Create a record from its arguments.
       Left bndr -> do
         -- We have the bndr, so we can get at the type
-        labels <- getFieldLabels (Var.varType bndr)
+        labels <- MonadState.lift vsType $ getFieldLabels (Var.varType bndr)
         return $ zipWith mkassign labels $ map (either exprToVHDLExpr id) args
         where
           mkassign :: AST.VHDLId -> AST.Expr -> AST.ConcSm
@@ -464,7 +465,7 @@ genApplication dst f args =
 
 -- Returns the VHDLId of the vector function with the given name for the given
 -- element type. Generates -- this function if needed.
-vectorFunId :: Type.Type -> String -> VHDLSession AST.VHDLId
+vectorFunId :: Type.Type -> String -> TypeSession AST.VHDLId
 vectorFunId el_ty fname = do
   let error_msg = "\nGenerate.vectorFunId: Can not construct vector function for element: " ++ pprString el_ty
   elemTM <- vhdl_ty error_msg el_ty
@@ -874,40 +875,40 @@ genUnconsVectorFuns elemTM vectorTM  =
 -- builder function.
 globalNameTable :: NameTable
 globalNameTable = Map.fromList
-  [ (exId             , (2, genFCall                ) )
-  , (replaceId        , (3, genFCall                ) )
-  , (headId           , (1, genFCall                ) )
-  , (lastId           , (1, genFCall                ) )
-  , (tailId           , (1, genFCall                ) )
-  , (initId           , (1, genFCall                ) )
-  , (takeId           , (2, genFCall                ) )
-  , (dropId           , (2, genFCall                ) )
-  , (selId            , (4, genFCall                ) )
-  , (plusgtId         , (2, genFCall                ) )
-  , (ltplusId         , (2, genFCall                ) )
-  , (plusplusId       , (2, genFCall                ) )
+  [ (exId             , (2, genFCall False          ) )
+  , (replaceId        , (3, genFCall False          ) )
+  , (headId           , (1, genFCall True           ) )
+  , (lastId           , (1, genFCall True           ) )
+  , (tailId           , (1, genFCall False          ) )
+  , (initId           , (1, genFCall False          ) )
+  , (takeId           , (2, genFCall False          ) )
+  , (dropId           , (2, genFCall False          ) )
+  , (selId            , (4, genFCall False          ) )
+  , (plusgtId         , (2, genFCall False          ) )
+  , (ltplusId         , (2, genFCall False          ) )
+  , (plusplusId       , (2, genFCall False          ) )
   , (mapId            , (2, genMap                  ) )
   , (zipWithId        , (3, genZipWith              ) )
   , (foldlId          , (3, genFoldl                ) )
   , (foldrId          , (3, genFoldr                ) )
   , (zipId            , (2, genZip                  ) )
   , (unzipId          , (1, genUnzip                ) )
-  , (shiftlId         , (2, genFCall                ) )
-  , (shiftrId         , (2, genFCall                ) )
-  , (rotlId           , (1, genFCall                ) )
-  , (rotrId           , (1, genFCall                ) )
+  , (shiftlId         , (2, genFCall False          ) )
+  , (shiftrId         , (2, genFCall False          ) )
+  , (rotlId           , (1, genFCall False          ) )
+  , (rotrId           , (1, genFCall False          ) )
   , (concatId         , (1, genConcat               ) )
-  , (reverseId        , (1, genFCall                ) )
+  , (reverseId        , (1, genFCall False          ) )
   , (iteratenId       , (3, genIteraten             ) )
   , (iterateId        , (2, genIterate              ) )
   , (generatenId      , (3, genGeneraten            ) )
   , (generateId       , (2, genGenerate             ) )
-  , (emptyId          , (0, genFCall                ) )
-  , (singletonId      , (1, genFCall                ) )
-  , (copynId          , (2, genFCall                ) )
+  , (emptyId          , (0, genFCall False          ) )
+  , (singletonId      , (1, genFCall False          ) )
+  , (copynId          , (2, genFCall False          ) )
   , (copyId           , (1, genCopy                 ) )
-  , (lengthTId        , (1, genFCall                ) )
-  , (nullId           , (1, genFCall                ) )
+  , (lengthTId        , (1, genFCall False          ) )
+  , (nullId           , (1, genFCall False          ) )
   , (hwxorId          , (2, genOperator2 AST.Xor    ) )
   , (hwandId          , (2, genOperator2 AST.And    ) )
   , (hworId           , (2, genOperator2 AST.Or     ) )