Merge branch 'master' of git://github.com/christiaanb/clash into cλash
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 7fb0dc235d2a60d8a3e3798b8a655e9f6f3218f7..9828d5ceea96704f92b979766dd06be5bfcc7523 100644 (file)
@@ -9,6 +9,7 @@ module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -115,14 +116,31 @@ castsimpl expr = return expr
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
--- let recursification
+-- let derecursification
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ MkCore.mkCoreLets newbinds res
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Create nonrec bindings for each liftable binding and a single recursive
+    -- binding for all others
+    newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
@@ -545,7 +563,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -595,9 +613,25 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = 
-  case letexpr of
-    (Let (Rec binds) (Var res)) -> (args, binds, res)
-    _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+splitNormalized expr = (args, binds, res)
   where
     (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)