Revert "Don't generate VHDL for state packing."
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 647bd1902592c46db67c132a9ff744901a82b971..ec3ed56bde3da6b606a1855e0dfdd34d50bca922 100644 (file)
@@ -90,6 +90,30 @@ castprop expr = return expr
 -- Perform this transform everywhere
 castproptop = everywhere ("castprop", castprop)
 
+--------------------------------
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
+--------------------------------
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+  -- Don't extract values that are already simpl
+  local_var <- Trans.lift $ is_local_var val
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr val
+  if (not local_var) && repr
+    then do
+      -- Generate a binder for the expression
+      id <- Trans.lift $ mkBinderFor val "castval"
+      -- Extract the expression
+      change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+    else
+      return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
 --------------------------------
 -- let recursification
 --------------------------------
@@ -113,7 +137,7 @@ letsimpl expr@(Let (Rec binds) res) = do
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
+      id <- Trans.lift $ mkBinderFor res "foo"
       let bind = (id, res)
       change $ Let (Rec (bind:binds)) (Var id)
     else
@@ -170,11 +194,41 @@ letremoveunused expr@(Let (Rec binds) res) = do
       bound_exprs = map snd binds
       -- For each bind check if the bind is used by res or any of the bound
       -- expressions
-      dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs)
+      dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs)
 -- Leave all other expressions unchanged
 letremoveunused expr = return expr
 letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 
+--------------------------------
+-- Identical let binding merging
+--------------------------------
+-- Merge two bindings in a let if they are identical 
+-- TODO: We would very much like to use GHC's CSE module for this, but that
+-- doesn't track if something changed or not, so we can't use it properly.
+letmerge, letmergetop :: Transform
+letmerge expr@(Let (Rec binds) res) = do
+  binds' <- domerge binds
+  return (Let (Rec binds') res)
+  where
+    domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
+    domerge [] = return []
+    domerge (e:es) = do 
+      es' <- mapM (mergebinds e) es
+      es'' <- domerge es'
+      return (e:es'')
+
+    -- Uses the second bind to simplify the second bind, if applicable.
+    mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    mergebinds (b1, e1) (b2, e2)
+      -- Identical expressions? Replace the second binding with a reference to
+      -- the first binder.
+      | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1)
+      -- Different expressions? Don't change
+      | otherwise = return (b2, e2)
+-- Leave all other expressions unchanged
+letmerge expr = return expr
+letmergetop = everywhere ("letmerge", letmerge)
+    
 --------------------------------
 -- Function inlining
 --------------------------------
@@ -207,7 +261,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
-      id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
+      id <- Trans.lift $ mkBinderFor scrut "scrut"
       change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
     else
       return expr
@@ -301,7 +355,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -341,7 +395,7 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
+      id <- Trans.lift $ mkBinderFor arg "arg"
       change $ Let (Rec [(id, arg)]) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
@@ -491,7 +545,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -520,9 +574,9 @@ normalizeExpr what expr = do
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when