Fix the trace output of normalized functions.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 7571a6f3b0fbf21b33673fe59476bd217ab336ed..b2b4bd86f0080693d29f70183ef469083b15bdfc 100644 (file)
@@ -4,7 +4,7 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (getNormalized) where
+module CLasH.Normalize (getNormalized, normalizeExpr) where
 
 -- Standard modules
 import Debug.Trace
@@ -39,6 +39,7 @@ import CLasH.Normalize.NormalizeTools
 import CLasH.VHDL.VHDLTypes
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
+import CLasH.Utils.Core.BinderTools
 import CLasH.Utils.Pretty
 
 --------------------------------
@@ -51,7 +52,7 @@ import CLasH.Utils.Pretty
 eta, etatop :: Transform
 eta expr | is_fun expr && not (is_lam expr) = do
   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
-  id <- mkInternalVar "param" arg_ty
+  id <- Trans.lift $ mkInternalVar "param" arg_ty
   change (Lam id (App expr (Var id)))
 -- Leave all other expressions unchanged
 eta e = return e
@@ -112,7 +113,7 @@ letsimpl expr@(Let (Rec binds) res) = do
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- mkInternalVar "foo" (CoreUtils.exprType res)
+      id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
       let bind = (id, res)
       change $ Let (Rec (bind:binds)) (Var id)
     else
@@ -156,6 +157,24 @@ letflattop = everywhere ("letflat", letflat)
 letremovetop :: Transform
 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
+--------------------------------
+-- Unused let binding removal
+--------------------------------
+letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (Rec binds) res) = do
+  -- Filter out all unused binds.
+  let binds' = filter dobind binds
+  -- Only set the changed flag if binds got removed
+  changeif (length binds' /= length binds) (Let (Rec binds') res)
+    where
+      bound_exprs = map snd binds
+      -- For each bind check if the bind is used by res or any of the bound
+      -- expressions
+      dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs)
+-- Leave all other expressions unchanged
+letremoveunused expr = return expr
+letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
+
 --------------------------------
 -- Function inlining
 --------------------------------
@@ -188,7 +207,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
-      id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
+      id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
       change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
     else
       return expr
@@ -262,7 +281,7 @@ casesimpl expr@(Case scrut b ty alts) = do
             -- Create on new binder that will actually capture a value in this
             -- case statement, and return it.
             let bty = (Id.idType b)
-            id <- mkInternalVar "sel" bty
+            id <- Trans.lift $ mkInternalVar "sel" bty
             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
             let caseexpr = Case scrut b bty [(con, binders, Var id)]
             return (wildbndrs!!i, Just (b, caseexpr))
@@ -282,7 +301,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -322,7 +341,7 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- mkInternalVar "arg" (CoreUtils.exprType arg)
+      id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
       change $ Let (Rec [(id, arg)]) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
@@ -358,7 +377,7 @@ argprop expr@(App _ _) | is_var fexpr = do
           -- the old body applied to some arguments.
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
-          newf <- mkFunction f newbody
+          newf <- Trans.lift $ mkFunction f newbody
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
@@ -404,7 +423,7 @@ argprop expr@(App _ _) | is_var fexpr = do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
           -- TODO: preserve original naming?
-          id <- mkBinderFor arg "param"
+          id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
           -- to a new id and just pass that new id to the old function body.
           return ([arg], [id], mkReferenceTo id) 
@@ -451,7 +470,7 @@ funextract expr@(App _ _) | is_var fexpr = do
       -- by the argument expression.
       let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
       let body = MkCore.mkCoreLams free_vars arg
-      id <- mkBinderFor body "fun"
+      id <- Trans.lift $ mkBinderFor body "fun"
       Trans.lift $ addGlobalBind id body
       -- Replace the argument with a reference to the new function, applied to
       -- all vars it uses.
@@ -472,7 +491,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -487,14 +506,23 @@ getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
       error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
     else do
       expr <- getBinding bndr
+      normalizeExpr (show bndr) expr
+
+-- | Normalize an expression
+normalizeExpr ::
+  String -- ^ What are we normalizing? For debug output only.
+  -> CoreSyn.CoreExpr -- ^ The expression to normalize 
+  -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
+
+normalizeExpr what expr = do
       -- Introduce an empty Let at the top level, so there will always be
       -- a let in the expression (none of the transformations will remove
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when