Make letflat work with non-recursive lets.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 7fb0dc235d2a60d8a3e3798b8a655e9f6f3218f7..07ded20bbb8d387c38372f749c9c79789d38cdbf 100644 (file)
@@ -9,6 +9,7 @@ module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -106,7 +107,7 @@ castsimpl expr@(Cast val ty) = do
       -- Generate a binder for the expression
       id <- Trans.lift $ mkBinderFor val "castval"
       -- Extract the expression
       -- Generate a binder for the expression
       id <- Trans.lift $ mkBinderFor val "castval"
       -- Extract the expression
-      change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+      change $ Let (NonRec id val) (Cast (Var id) ty)
     else
       return expr
 -- Leave all other expressions unchanged
     else
       return expr
 -- Leave all other expressions unchanged
@@ -115,14 +116,31 @@ castsimpl expr = return expr
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
--- let recursification
+-- let derecursification
 --------------------------------
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ MkCore.mkCoreLets newbinds res
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Create nonrec bindings for each liftable binding and a single recursive
+    -- binding for all others
+    newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
 
 --------------------------------
 -- let simplification
@@ -130,7 +148,7 @@ letrectop = everywhere ("letrec", letrec)
 letsimpl, letsimpltop :: Transform
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
 letsimpl, letsimpltop :: Transform
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl expr@(Let (Rec binds) res) = do
+letsimpl expr@(Let binds res) = do
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
@@ -138,8 +156,7 @@ letsimpl expr@(Let (Rec binds) res) = do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
       id <- Trans.lift $ mkBinderFor res "foo"
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
       id <- Trans.lift $ mkBinderFor res "foo"
-      let bind = (id, res)
-      change $ Let (Rec (bind:binds)) (Var id)
+      change $ Let binds (Let (NonRec id  res) (Var id))
     else
       -- If the result is already a local var, don't extract it.
       return expr
     else
       -- If the result is already a local var, don't extract it.
       return expr
@@ -152,23 +169,14 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 --------------------------------
 -- let flattening
 --------------------------------
 --------------------------------
 -- let flattening
 --------------------------------
+-- Takes a let that binds another let, and turns that into two nested lets.
+-- e.g., from:
+-- let b = (let b' = expr' in res') in res
+-- to:
+-- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
 letflat, letflattop :: Transform
-letflat (Let (Rec binds) expr) = do
-  -- Turn each binding into a list of bindings (possibly containing just one
-  -- element, of course)
-  bindss <- Monad.mapM flatbind binds
-  -- Concat all the bindings
-  let binds' = concat bindss
-  -- Return the new let. We don't use change here, since possibly nothing has
-  -- changed. If anything has changed, flatbind has already flagged that
-  -- change.
-  return $ Let (Rec binds') expr
-  where
-    -- Turns a binding of a let into a multiple bindings, or any other binding
-    -- into a list with just that binding
-    flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
-    flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
-    flatbind (b, expr) = return [(b, expr)]
+letflat (Let (NonRec b (Let (NonRec b' expr')  res')) res) = 
+  change $ Let (NonRec b' expr') (Let (NonRec b res') res)
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
@@ -262,7 +270,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   if repr
     then do
       id <- Trans.lift $ mkBinderFor scrut "scrut"
   if repr
     then do
       id <- Trans.lift $ mkBinderFor scrut "scrut"
-      change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
+      change $ Let (NonRec id scrut) (Case (Var id) b ty alts)
     else
       return expr
 -- Leave all other expressions unchanged
     else
       return expr
 -- Leave all other expressions unchanged
@@ -396,7 +404,7 @@ appsimpl expr@(App f arg) = do
   if repr && not local_var
     then do -- Extract representable arguments
       id <- Trans.lift $ mkBinderFor arg "arg"
   if repr && not local_var
     then do -- Extract representable arguments
       id <- Trans.lift $ mkBinderFor arg "arg"
-      change $ Let (Rec [(id, arg)]) (App f (Var id))
+      change $ Let (NonRec id arg) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
@@ -545,7 +553,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -569,15 +577,11 @@ normalizeExpr ::
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
-      -- Introduce an empty Let at the top level, so there will always be
-      -- a let in the expression (none of the transformations will remove
-      -- the last let).
-      let expr' = Let (Rec []) expr
       -- Normalize this expression
       -- Normalize this expression
-      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
-      expr'' <- dotransforms transforms expr'
-      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
-      return expr''
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr ) ++ "\n") $ return ()
+      expr' <- dotransforms transforms expr
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      return expr'
 
 -- | Get the value that is bound to the given binder at top level. Fails when
 --   there is no such binding.
 
 -- | Get the value that is bound to the given binder at top level. Fails when
 --   there is no such binding.
@@ -595,9 +599,25 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = 
-  case letexpr of
-    (Let (Rec binds) (Var res)) -> (args, binds, res)
-    _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+splitNormalized expr = (args, binds, res)
   where
     (args, letexpr) = CoreSyn.collectBinders expr
   where
     (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)