Ignore casts that just repack state. Don't make VHDL for them, their type is empty
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 6000bb127832d380bb84b39981cf3c34b2e97e90..1d40f92d52f7b94989fc9215bcfd15723a9001e0 100644 (file)
@@ -4,7 +4,7 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (getNormalized, normalizeExpr) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 
 -- Standard modules
 import Debug.Trace
@@ -137,7 +137,7 @@ letsimpl expr@(Let (Rec binds) res) = do
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
+      id <- Trans.lift $ mkBinderFor res "foo"
       let bind = (id, res)
       change $ Let (Rec (bind:binds)) (Var id)
     else
@@ -261,7 +261,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
-      id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
+      id <- Trans.lift $ mkBinderFor scrut "scrut"
       change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
     else
       return expr
@@ -355,7 +355,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -395,7 +395,7 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
+      id <- Trans.lift $ mkBinderFor arg "arg"
       change $ Let (Rec [(id, arg)]) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
@@ -574,9 +574,9 @@ normalizeExpr what expr = do
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when
@@ -589,3 +589,31 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+--   bindings and the result binder.
+splitNormalized ::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+  where
+    (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)