Merge branch 'master' of git://github.com/christiaanb/clash into cλash
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 6000bb127832d380bb84b39981cf3c34b2e97e90..9828d5ceea96704f92b979766dd06be5bfcc7523 100644 (file)
@@ -4,11 +4,12 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (getNormalized, normalizeExpr) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -115,14 +116,31 @@ castsimpl expr = return expr
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
--- let recursification
+-- let derecursification
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ MkCore.mkCoreLets newbinds res
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Create nonrec bindings for each liftable binding and a single recursive
+    -- binding for all others
+    newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
@@ -137,7 +155,7 @@ letsimpl expr@(Let (Rec binds) res) = do
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
+      id <- Trans.lift $ mkBinderFor res "foo"
       let bind = (id, res)
       change $ Let (Rec (bind:binds)) (Var id)
     else
@@ -261,7 +279,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
-      id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
+      id <- Trans.lift $ mkBinderFor scrut "scrut"
       change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
     else
       return expr
@@ -355,7 +373,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -395,7 +413,7 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
+      id <- Trans.lift $ mkBinderFor arg "arg"
       change $ Let (Rec [(id, arg)]) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
@@ -545,7 +563,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -574,9 +592,9 @@ normalizeExpr what expr = do
       -- the last let).
       let expr' = Let (Rec []) expr
       -- Normalize this expression
-      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when
@@ -589,3 +607,31 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+--   bindings and the result binder.
+splitNormalized ::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+  where
+    (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)