When inlining top level functions, guarantee uniqueness.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 6a4825d42406343c5b6eb94191fdaf5b4ed554e0..1f0509da5a2ee674aae1117c9150615794152638 100644 (file)
@@ -23,6 +23,7 @@ import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified TcType
+import qualified Name
 import qualified Id
 import qualified Var
 import qualified VarSet
@@ -115,6 +116,36 @@ castsimpl expr = return expr
 -- Perform this transform everywhere
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
+
+--------------------------------
+-- Lambda simplication
+--------------------------------
+-- Ensure that a lambda always evaluates to a let expressions or a simple
+-- variable reference.
+lambdasimpl, lambdasimpltop :: Transform
+-- Don't simplify a lambda that evaluates to let, since this is already
+-- normal form (and would cause infinite loops).
+lambdasimpl expr@(Lam _ (Let _ _)) = return expr
+-- Put the of a lambda in its own binding, but not when the expression is
+-- already a local variable, or not representable (to prevent loops with
+-- inlinenonrep).
+lambdasimpl expr@(Lam bndr res) = do
+  repr <- isRepr res
+  local_var <- Trans.lift $ is_local_var res
+  if not local_var && repr
+    then do
+      id <- Trans.lift $ mkBinderFor res "res"
+      change $ Lam bndr (Let (NonRec id res) (Var id))
+    else
+      -- If the result is already a local var or not representable, don't
+      -- extract it.
+      return expr
+
+-- Leave all other expressions unchanged
+lambdasimpl expr = return expr
+-- Perform this transform everywhere
+lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl)
+
 --------------------------------
 -- let derecursification
 --------------------------------
@@ -123,15 +154,12 @@ letderec expr@(Let (Rec binds) res) = case liftable of
   -- Nothing is liftable, just return
   [] -> return expr
   -- Something can be lifted, generate a new let expression
-  _ -> change $ MkCore.mkCoreLets newbinds res
+  _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res)
   where
     -- Make a list of all the binders bound in this recursive let
     bndrs = map fst binds
     -- See which bindings are liftable
     (liftable, nonliftable) = List.partition canlift binds
-    -- Create nonrec bindings for each liftable binding and a single recursive
-    -- binding for all others
-    newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
     -- Any expression that does not use any of the binders in this recursive let
     -- can be lifted into a nonrec let. It can't use its own binder either,
     -- since that would mean the binding is self-recursive and should be in a
@@ -178,24 +206,55 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 -- to:
 -- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
-letflat (Let (NonRec b (Let (NonRec b' expr')  res')) res) = 
-  change $ Let (NonRec b' expr') (Let (NonRec b res') res)
+-- Turn a nonrec let that binds a let into two nested lets.
+letflat (Let (NonRec b (Let binds  res')) res) = 
+  change $ Let binds (Let (NonRec b res') res)
+letflat (Let (Rec binds) expr) = do
+  -- Flatten each binding.
+  binds' <- Utils.concatM $ Monad.mapM flatbind binds
+  -- Return the new let. We don't use change here, since possibly nothing has
+  -- changed. If anything has changed, flatbind has already flagged that
+  -- change.
+  return $ Let (Rec binds') expr
+  where
+    -- Turns a binding of a let into a multiple bindings, or any other binding
+    -- into a list with just that binding
+    flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
+    flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
+    flatbind (b, Let (NonRec b' expr') expr) = change [(b, expr), (b', expr')]
+    flatbind (b, expr) = return [(b, expr)]
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
+--------------------------------
+-- empty let removal
+--------------------------------
+-- Remove empty (recursive) lets
+letremove, letremovetop :: Transform
+letremove (Let (Rec []) res) = change $ res
+-- Leave all other expressions unchanged
+letremove expr = return expr
+-- Perform this transform everywhere
+letremovetop = everywhere ("letremove", letremove)
+
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
-letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+letremovesimpletop :: Transform
+letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Unused let binding removal
 --------------------------------
 letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (NonRec b bound) res) = do
+  let used = expr_uses_binders [b] res
+  if used
+    then return expr
+    else change res
 letremoveunused expr@(Let (Rec binds) res) = do
   -- Filter out all unused binds.
   let binds' = filter dobind binds
@@ -210,6 +269,7 @@ letremoveunused expr@(Let (Rec binds) res) = do
 letremoveunused expr = return expr
 letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 
+{-
 --------------------------------
 -- Identical let binding merging
 --------------------------------
@@ -217,9 +277,10 @@ letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 -- TODO: We would very much like to use GHC's CSE module for this, but that
 -- doesn't track if something changed or not, so we can't use it properly.
 letmerge, letmergetop :: Transform
-letmerge expr@(Let (Rec binds) res) = do
+letmerge expr@(Let _ _) = do
+  let (binds, res) = flattenLets expr
   binds' <- domerge binds
-  return (Let (Rec binds') res)
+  return $ mkNonRecLets binds' res
   where
     domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
     domerge [] = return []
@@ -239,7 +300,8 @@ letmerge expr@(Let (Rec binds) res) = do
 -- Leave all other expressions unchanged
 letmerge expr = return expr
 letmergetop = everywhere ("letmerge", letmerge)
-    
+-}
+
 --------------------------------
 -- Function inlining
 --------------------------------
@@ -258,6 +320,40 @@ letmergetop = everywhere ("letmerge", letmerge)
 inlinenonreptop :: Transform
 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
 
+inlinetoplevel, inlinetopleveltop :: Transform
+-- Any system name is candidate for inlining. Never inline user-defined
+-- functions, to preserver structure.
+inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
+  norm <- isNormalizeable f
+  -- See if this is a top level binding for which we have a body
+  body_maybe <- Trans.lift $ getGlobalBind f
+  if norm && Maybe.isJust body_maybe
+    then do
+      -- Get the normalized version
+      norm <- Trans.lift $ getNormalized f
+      if needsInline norm 
+        then do
+          -- Regenerate all uniques in the to-be-inlined expression
+          norm_uniqued <- Trans.lift $ genUniques norm
+          change norm_uniqued
+        else
+          return expr
+    else
+      -- No body or not normalizeable.
+      return expr
+-- Leave all other expressions unchanged
+inlinetoplevel expr = return expr
+inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
+
+needsInline :: CoreExpr -> Bool
+needsInline expr = case splitNormalized expr of
+  -- Inline any function that only has a single definition, it is probably
+  -- simple enough. This might inline some stuff that it shouldn't though it
+  -- will never inline user-defined functions (inlinetoplevel only tries
+  -- system names) and inlining should never break things.
+  (args, [bind], res) -> True
+  _ -> False
+
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
@@ -299,7 +395,7 @@ casesimpl expr@(Case scrut b ty alts) = do
   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
   let bindings = concat bindingss
   -- Replace the case with a let with bindings and a case
-  let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
+  let newlet = mkNonRecLets bindings (Case scrut b ty alts')
   -- If there are no non-wild binders, or this case is already a simple
   -- selector (i.e., a single alt with exactly one binding), already a simple
   -- selector altan no bindings (i.e., no wild binders in the original case),
@@ -322,7 +418,7 @@ casesimpl expr@(Case scrut b ty alts) = do
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
-    let bindings = Maybe.catMaybes (exprbinding_maybe : bindings_maybe)
+    let bindings = Maybe.catMaybes (bindings_maybe ++ [exprbinding_maybe])
     return (bindings, newalt)
     where
       -- Make wild alternatives for each binder
@@ -334,7 +430,7 @@ casesimpl expr@(Case scrut b ty alts) = do
       -- binding containing a case expression.
       dobndr :: CoreBndr -> Int -> TransformMonad (CoreBndr, Maybe (CoreBndr, CoreExpr))
       dobndr b i = do
-        repr <- isRepr (Var b)
+        repr <- isRepr b
         -- Is b wild (e.g., not a free var of expr. Since b is only in scope
         -- in expr, this means that b is unused if expr does not use it.)
         let wild = not (VarSet.elemVarSet b free_vars)
@@ -548,6 +644,25 @@ funextract expr = return expr
 -- Perform this transform everywhere
 funextracttop = everywhere ("funextract", funextract)
 
+--------------------------------
+-- Ensure that a function that just returns another function (or rather,
+-- another top-level binder) is still properly normalized. This is a temporary
+-- solution, we should probably integrate this pass with lambdasimpl and
+-- letsimpl instead.
+--------------------------------
+simplrestop expr@(Lam _ _) = return expr
+simplrestop expr@(Let _ _) = return expr
+simplrestop expr = do
+  local_var <- Trans.lift $ is_local_var expr
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr expr
+  if local_var || not repr
+    then
+      return expr
+    else do
+      id <- Trans.lift $ mkBinderFor expr "res" 
+      change $ Let (NonRec id expr) (Var id)
 --------------------------------
 -- End of transformations
 --------------------------------
@@ -556,7 +671,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
+transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -580,9 +695,10 @@ normalizeExpr ::
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
+      expr_uniqued <- genUniques expr
       -- Normalize this expression
-      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr ) ++ "\n") $ return ()
-      expr' <- dotransforms transforms expr
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr_uniqued ) ++ "\n") $ return ()
+      expr' <- dotransforms transforms expr_uniqued
       trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
       return expr'
 
@@ -609,18 +725,3 @@ splitNormalized expr = (args, binds, res)
     res = case resexpr of 
       (Var x) -> x
       _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
-
--- | Flattens nested lets into a single list of bindings. The expression
---   passed does not have to be a let expression, if it isn't an empty list of
---   bindings is returned.
-flattenLets ::
-  CoreExpr -- ^ The expression to flatten.
-  -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
-flattenLets (Let binds expr) = 
-  (bindings ++ bindings', expr')
-  where
-    -- Recursively flatten the contained expression
-    (bindings', expr') =flattenLets expr
-    -- Flatten our own bindings to remove the Rec / NonRec constructors
-    bindings = CoreSyn.flattenBinds [binds]
-flattenLets expr = ([], expr)