Merge branch 'master' of git://github.com/christiaanb/clash into cλash
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index f3a9d2e4662ca2e204d761fa1c80fc4272628a73..9828d5ceea96704f92b979766dd06be5bfcc7523 100644 (file)
@@ -4,11 +4,12 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (getNormalized, normalizeExpr) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -115,14 +116,31 @@ castsimpl expr = return expr
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
--- let recursification
+-- let derecursification
 --------------------------------
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ MkCore.mkCoreLets newbinds res
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Create nonrec bindings for each liftable binding and a single recursive
+    -- binding for all others
+    newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
 
 --------------------------------
 -- let simplification
@@ -545,7 +563,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -574,9 +592,9 @@ normalizeExpr what expr = do
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when
@@ -589,3 +607,31 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+--   bindings and the result binder.
+splitNormalized ::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+  where
+    (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)