Work around some bugs in the current clash to make reducer compile correctly
[matthijs/master-project/cλash.git] / cλash / CLasH / Utils / Core / CoreTools.hs
index 7377c9e36ce332e14c04f2a6cde234e47bf8eb95..094b70294ceabcca23e722df69c98e93e1bd85db 100644 (file)
@@ -23,23 +23,30 @@ import qualified OccName
 import qualified Type
 import qualified Id
 import qualified TyCon
 import qualified Type
 import qualified Id
 import qualified TyCon
+import qualified DataCon
 import qualified TysWiredIn
 import qualified Bag
 import qualified DynFlags
 import qualified SrcLoc
 import qualified CoreSyn
 import qualified Var
 import qualified TysWiredIn
 import qualified Bag
 import qualified DynFlags
 import qualified SrcLoc
 import qualified CoreSyn
 import qualified Var
+import qualified IdInfo
 import qualified VarSet
 import qualified Unique
 import qualified CoreUtils
 import qualified CoreFVs
 import qualified Literal
 import qualified VarSet
 import qualified Unique
 import qualified CoreUtils
 import qualified CoreFVs
 import qualified Literal
+import qualified MkCore
 
 -- Local imports
 
 -- Local imports
+import CLasH.Translator.TranslatorTypes
 import CLasH.Utils.GhcTools
 import CLasH.Utils.HsTools
 import CLasH.Utils.Pretty
 
 import CLasH.Utils.GhcTools
 import CLasH.Utils.HsTools
 import CLasH.Utils.Pretty
 
+-- | A single binding, used as a shortcut to simplify type signatures.
+type Binding = (CoreSyn.CoreBndr, CoreSyn.CoreExpr)
+
 -- | Evaluate a core Type representing type level int from the tfp
 -- library to a real int.
 eval_tfp_int :: HscTypes.HscEnv -> Type.Type -> Int
 -- | Evaluate a core Type representing type level int from the tfp
 -- library to a real int.
 eval_tfp_int :: HscTypes.HscEnv -> Type.Type -> Int
@@ -222,16 +229,52 @@ getLiterals app@(CoreSyn.App _ _) = literals
 
 getLiterals lit@(CoreSyn.Lit _) = [lit]
 
 
 getLiterals lit@(CoreSyn.Lit _) = [lit]
 
-reduceCoreListToHsList :: CoreSyn.CoreExpr -> [CoreSyn.CoreExpr]
-reduceCoreListToHsList app@(CoreSyn.App _ _) = out
+reduceCoreListToHsList :: 
+  [HscTypes.CoreModule] -- ^ The modules where parts of the list are hidden
+  -> CoreSyn.CoreExpr   -- ^ The refence to atleast one of the nodes
+  -> TranslatorSession [CoreSyn.CoreExpr]
+reduceCoreListToHsList cores app@(CoreSyn.App _ _) = do {
+  ; let { (fun, args) = CoreSyn.collectArgs app
+        ; len         = length args 
+        } ;
+  ; case len of
+      3 -> do {
+        ; let topelem = args!!1
+        ; case (args!!2) of
+            (varz@(CoreSyn.Var id)) -> do {
+              ; binds <- mapM (findExpr (isVarName id)) cores
+              ; otherelems <- reduceCoreListToHsList cores (head (Maybe.catMaybes binds))
+              ; return (topelem:otherelems)
+              }
+            (appz@(CoreSyn.App _ _)) -> do {
+              ; otherelems <- reduceCoreListToHsList cores appz
+              ; return (topelem:otherelems)
+              }
+            otherwise -> return [topelem]
+        }
+      otherwise -> return []
+  }
   where
   where
-    (fun, args) = CoreSyn.collectArgs app
-    len = length args
-    out = case len of
-          3 -> ((args!!1) : (reduceCoreListToHsList (args!!2)))
-          otherwise -> []
-
-reduceCoreListToHsList _ = []
+    isVarName :: Monad m => Var.Var -> Var.Var -> m Bool
+    isVarName lookfor bind = return $ (Var.varName lookfor) == (Var.varName bind)
+
+reduceCoreListToHsList _ _ = return []
+
+-- Is the given var the State data constructor?
+isStateCon :: Var.Var -> Bool
+isStateCon var = do
+  -- See if it is a DataConWrapId (not DataConWorkId, since State is a
+  -- newtype).
+  case Id.idDetails var of
+    IdInfo.DataConWrapId dc -> 
+      -- See if the datacon is the State datacon from the State type.
+      let tycon = DataCon.dataConTyCon dc
+          tyname = Name.getOccString tycon
+          dcname = Name.getOccString dc
+      in case (tyname, dcname) of
+        ("State", "State") -> True
+        _ -> False
+    _ -> False
 
 -- | Is the given type a State type?
 isStateType :: Type.Type -> Bool
 
 -- | Is the given type a State type?
 isStateType :: Type.Type -> Bool
@@ -256,6 +299,29 @@ hasStateType expr = case getType expr of
   Just ty -> isStateType ty
 
 
   Just ty -> isStateType ty
 
 
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreSyn.CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreSyn.CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (CoreSyn.Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)
+
+-- | Create bunch of nested non-recursive let expressions from the given
+-- bindings. The first binding is bound at the highest level (and thus
+-- available in all other bindings).
+mkNonRecLets :: [Binding] -> CoreSyn.CoreExpr -> CoreSyn.CoreExpr
+mkNonRecLets bindings expr = MkCore.mkCoreLets binds expr
+  where
+    binds = map (uncurry CoreSyn.NonRec) bindings
+
 -- | A class of things that (optionally) have a core Type. The type is
 -- optional, since Type expressions don't have a type themselves.
 class TypedThing t where
 -- | A class of things that (optionally) have a core Type. The type is
 -- optional, since Type expressions don't have a type themselves.
 class TypedThing t where