Make letmerge work with non-recursive lets.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 07ded20bbb8d387c38372f749c9c79789d38cdbf..0995dbd1fa9653708892d790e0fd4359c8f35579 100644 (file)
@@ -146,6 +146,9 @@ letderectop = everywhere ("letderec", letderec)
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
+-- Don't simplify a let that evaluates to another let, since this is already
+-- normal form (and would cause infinite loops with letflat below).
+letsimpl expr@(Let _ (Let _ _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
 letsimpl expr@(Let binds res) = do
@@ -182,12 +185,23 @@ letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
+--------------------------------
+-- empty let removal
+--------------------------------
+-- Remove empty (recursive) lets
+letremove, letremovetop :: Transform
+letremove (Let (Rec []) res) = change $ res
+-- Leave all other expressions unchanged
+letremove expr = return expr
+-- Perform this transform everywhere
+letremovetop = everywhere ("letremove", letremove)
+
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
-letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+letremovesimpletop :: Transform
+letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Unused let binding removal
@@ -214,9 +228,10 @@ letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 -- TODO: We would very much like to use GHC's CSE module for this, but that
 -- doesn't track if something changed or not, so we can't use it properly.
 letmerge, letmergetop :: Transform
-letmerge expr@(Let (Rec binds) res) = do
+letmerge expr@(Let _ _) = do
+  let (binds, res) = flattenLets expr
   binds' <- domerge binds
-  return (Let (Rec binds') res)
+  return $ MkCore.mkCoreLets (map (uncurry NonRec) binds') res
   where
     domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
     domerge [] = return []
@@ -553,7 +568,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -606,18 +621,3 @@ splitNormalized expr = (args, binds, res)
     res = case resexpr of 
       (Var x) -> x
       _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
-
--- | Flattens nested lets into a single list of bindings. The expression
---   passed does not have to be a let expression, if it isn't an empty list of
---   bindings is returned.
-flattenLets ::
-  CoreExpr -- ^ The expression to flatten.
-  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
-flattenLets (Let binds expr) = 
-  (bindings ++ bindings', expr')
-  where
-    -- Recursively flatten the contained expression
-    (bindings', expr') =flattenLets expr
-    -- Flatten our own bindings to remove the Rec / NonRec constructors
-    bindings = CoreSyn.flattenBinds [binds]
-flattenLets expr = ([], expr)