Add empty let removal normalization pass.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 9828d5ceea96704f92b979766dd06be5bfcc7523..8cbebad6978b29b687fbc4e6f1f133e2ac145b42 100644 (file)
@@ -107,7 +107,7 @@ castsimpl expr@(Cast val ty) = do
       -- Generate a binder for the expression
       id <- Trans.lift $ mkBinderFor val "castval"
       -- Extract the expression
-      change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+      change $ Let (NonRec id val) (Cast (Var id) ty)
     else
       return expr
 -- Leave all other expressions unchanged
@@ -146,9 +146,12 @@ letderectop = everywhere ("letderec", letderec)
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
+-- Don't simplify a let that evaluates to another let, since this is already
+-- normal form (and would cause infinite loops with letflat below).
+letsimpl expr@(Let _ (Let _ _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl expr@(Let (Rec binds) res) = do
+letsimpl expr@(Let binds res) = do
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
@@ -156,8 +159,7 @@ letsimpl expr@(Let (Rec binds) res) = do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
       id <- Trans.lift $ mkBinderFor res "foo"
-      let bind = (id, res)
-      change $ Let (Rec (bind:binds)) (Var id)
+      change $ Let binds (Let (NonRec id  res) (Var id))
     else
       -- If the result is already a local var, don't extract it.
       return expr
@@ -170,34 +172,36 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 --------------------------------
 -- let flattening
 --------------------------------
+-- Takes a let that binds another let, and turns that into two nested lets.
+-- e.g., from:
+-- let b = (let b' = expr' in res') in res
+-- to:
+-- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
-letflat (Let (Rec binds) expr) = do
-  -- Turn each binding into a list of bindings (possibly containing just one
-  -- element, of course)
-  bindss <- Monad.mapM flatbind binds
-  -- Concat all the bindings
-  let binds' = concat bindss
-  -- Return the new let. We don't use change here, since possibly nothing has
-  -- changed. If anything has changed, flatbind has already flagged that
-  -- change.
-  return $ Let (Rec binds') expr
-  where
-    -- Turns a binding of a let into a multiple bindings, or any other binding
-    -- into a list with just that binding
-    flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
-    flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
-    flatbind (b, expr) = return [(b, expr)]
+letflat (Let (NonRec b (Let (NonRec b' expr')  res')) res) = 
+  change $ Let (NonRec b' expr') (Let (NonRec b res') res)
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
+--------------------------------
+-- empty let removal
+--------------------------------
+-- Remove empty (recursive) lets
+letremove, letremovetop :: Transform
+letremove (Let (Rec []) res) = change $ res
+-- Leave all other expressions unchanged
+letremove expr = return expr
+-- Perform this transform everywhere
+letremovetop = everywhere ("letremove", letremove)
+
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
-letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+letremovesimpletop :: Transform
+letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Unused let binding removal
@@ -280,7 +284,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   if repr
     then do
       id <- Trans.lift $ mkBinderFor scrut "scrut"
-      change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
+      change $ Let (NonRec id scrut) (Case (Var id) b ty alts)
     else
       return expr
 -- Leave all other expressions unchanged
@@ -414,7 +418,7 @@ appsimpl expr@(App f arg) = do
   if repr && not local_var
     then do -- Extract representable arguments
       id <- Trans.lift $ mkBinderFor arg "arg"
-      change $ Let (Rec [(id, arg)]) (App f (Var id))
+      change $ Let (NonRec id arg) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
@@ -563,7 +567,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -587,15 +591,11 @@ normalizeExpr ::
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
-      -- Introduce an empty Let at the top level, so there will always be
-      -- a let in the expression (none of the transformations will remove
-      -- the last let).
-      let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
-      expr'' <- dotransforms transforms expr'
-      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
-      return expr''
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr ) ++ "\n") $ return ()
+      expr' <- dotransforms transforms expr
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      return expr'
 
 -- | Get the value that is bound to the given binder at top level. Fails when
 --   there is no such binding.