Make letflat work with non-recursive lets.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 6eee47b3d39bcd1391ba4d1aba49b5c0aa83ae84..07ded20bbb8d387c38372f749c9c79789d38cdbf 100644 (file)
@@ -148,7 +148,7 @@ letderectop = everywhere ("letderec", letderec)
 letsimpl, letsimpltop :: Transform
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl expr@(Let (Rec binds) res) = do
+letsimpl expr@(Let binds res) = do
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
@@ -156,8 +156,7 @@ letsimpl expr@(Let (Rec binds) res) = do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
       id <- Trans.lift $ mkBinderFor res "foo"
-      let bind = (id, res)
-      change $ Let (Rec (bind:binds)) (Var id)
+      change $ Let binds (Let (NonRec id  res) (Var id))
     else
       -- If the result is already a local var, don't extract it.
       return expr
@@ -170,23 +169,14 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 --------------------------------
 -- let flattening
 --------------------------------
+-- Takes a let that binds another let, and turns that into two nested lets.
+-- e.g., from:
+-- let b = (let b' = expr' in res') in res
+-- to:
+-- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
-letflat (Let (Rec binds) expr) = do
-  -- Turn each binding into a list of bindings (possibly containing just one
-  -- element, of course)
-  bindss <- Monad.mapM flatbind binds
-  -- Concat all the bindings
-  let binds' = concat bindss
-  -- Return the new let. We don't use change here, since possibly nothing has
-  -- changed. If anything has changed, flatbind has already flagged that
-  -- change.
-  return $ Let (Rec binds') expr
-  where
-    -- Turns a binding of a let into a multiple bindings, or any other binding
-    -- into a list with just that binding
-    flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
-    flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
-    flatbind (b, expr) = return [(b, expr)]
+letflat (Let (NonRec b (Let (NonRec b' expr')  res')) res) = 
+  change $ Let (NonRec b' expr') (Let (NonRec b res') res)
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere