Split substitute into substitute and substitute_clone.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 647bd1902592c46db67c132a9ff744901a82b971..6fd6a2e03cb0df3c5e553a57306f294b48e5d1cb 100644 (file)
@@ -4,11 +4,12 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (getNormalized, normalizeExpr) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -22,6 +23,7 @@ import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified TcType
+import qualified Name
 import qualified Id
 import qualified Var
 import qualified VarSet
@@ -62,8 +64,10 @@ etatop = notappargs ("eta", eta)
 -- β-reduction
 --------------------------------
 beta, betatop :: Transform
--- Substitute arg for x in expr
-beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
+-- Substitute arg for x in expr. For value lambda's, also clone before
+-- substitution.
+beta (App (Lam x expr) arg) | CoreSyn.isTyVar x = setChanged >> substitute x arg expr
+                            | otherwise      = setChanged >> substitute_clone x arg expr
 -- Propagate the application into the let
 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
 -- Propagate the application into each of the alternatives
@@ -91,31 +95,101 @@ castprop expr = return expr
 castproptop = everywhere ("castprop", castprop)
 
 --------------------------------
--- let recursification
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+  -- Don't extract values that are already simpl
+  local_var <- Trans.lift $ is_local_var val
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr val
+  if (not local_var) && repr
+    then do
+      -- Generate a binder for the expression
+      id <- Trans.lift $ mkBinderFor val "castval"
+      -- Extract the expression
+      change $ Let (NonRec id val) (Cast (Var id) ty)
+    else
+      return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
+
+--------------------------------
+-- Lambda simplication
+--------------------------------
+-- Ensure that a lambda always evaluates to a let expressions or a simple
+-- variable reference.
+lambdasimpl, lambdasimpltop :: Transform
+-- Don't simplify a lambda that evaluates to let, since this is already
+-- normal form (and would cause infinite loops).
+lambdasimpl expr@(Lam _ (Let _ _)) = return expr
+-- Put the of a lambda in its own binding, but not when the expression is
+-- already a local variable, or not representable (to prevent loops with
+-- inlinenonrep).
+lambdasimpl expr@(Lam bndr res) = do
+  repr <- isRepr res
+  local_var <- Trans.lift $ is_local_var res
+  if not local_var && repr
+    then do
+      id <- Trans.lift $ mkBinderFor res "res"
+      change $ Lam bndr (Let (NonRec id res) (Var id))
+    else
+      -- If the result is already a local var or not representable, don't
+      -- extract it.
+      return expr
+
+-- Leave all other expressions unchanged
+lambdasimpl expr = return expr
+-- Perform this transform everywhere
+lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl)
+
+--------------------------------
+-- let derecursification
+--------------------------------
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res)
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
+-- Don't simplify a let that evaluates to another let, since this is already
+-- normal form (and would cause infinite loops with letflat below).
+letsimpl expr@(Let _ (Let _ _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl expr@(Let (Rec binds) res) = do
+letsimpl expr@(Let binds res) = do
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
-      let bind = (id, res)
-      change $ Let (Rec (bind:binds)) (Var id)
+      id <- Trans.lift $ mkBinderFor res "foo"
+      change $ Let binds (Let (NonRec id  res) (Var id))
     else
       -- If the result is already a local var, don't extract it.
       return expr
@@ -128,13 +202,18 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 --------------------------------
 -- let flattening
 --------------------------------
+-- Takes a let that binds another let, and turns that into two nested lets.
+-- e.g., from:
+-- let b = (let b' = expr' in res') in res
+-- to:
+-- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
+-- Turn a nonrec let that binds a let into two nested lets.
+letflat (Let (NonRec b (Let binds  res')) res) = 
+  change $ Let binds (Let (NonRec b res') res)
 letflat (Let (Rec binds) expr) = do
-  -- Turn each binding into a list of bindings (possibly containing just one
-  -- element, of course)
-  bindss <- Monad.mapM flatbind binds
-  -- Concat all the bindings
-  let binds' = concat bindss
+  -- Flatten each binding.
+  binds' <- Utils.concatM $ Monad.mapM flatbind binds
   -- Return the new let. We don't use change here, since possibly nothing has
   -- changed. If anything has changed, flatbind has already flagged that
   -- change.
@@ -144,23 +223,40 @@ letflat (Let (Rec binds) expr) = do
     -- into a list with just that binding
     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
+    flatbind (b, Let (NonRec b' expr') expr) = change [(b, expr), (b', expr')]
     flatbind (b, expr) = return [(b, expr)]
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
+--------------------------------
+-- empty let removal
+--------------------------------
+-- Remove empty (recursive) lets
+letremove, letremovetop :: Transform
+letremove (Let (Rec []) res) = change $ res
+-- Leave all other expressions unchanged
+letremove expr = return expr
+-- Perform this transform everywhere
+letremovetop = everywhere ("letremove", letremove)
+
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
-letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+letremovesimpletop :: Transform
+letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Unused let binding removal
 --------------------------------
 letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (NonRec b bound) res) = do
+  let used = expr_uses_binders [b] res
+  if used
+    then return expr
+    else change res
 letremoveunused expr@(Let (Rec binds) res) = do
   -- Filter out all unused binds.
   let binds' = filter dobind binds
@@ -170,11 +266,44 @@ letremoveunused expr@(Let (Rec binds) res) = do
       bound_exprs = map snd binds
       -- For each bind check if the bind is used by res or any of the bound
       -- expressions
-      dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs)
+      dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs)
 -- Leave all other expressions unchanged
 letremoveunused expr = return expr
 letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 
+{-
+--------------------------------
+-- Identical let binding merging
+--------------------------------
+-- Merge two bindings in a let if they are identical 
+-- TODO: We would very much like to use GHC's CSE module for this, but that
+-- doesn't track if something changed or not, so we can't use it properly.
+letmerge, letmergetop :: Transform
+letmerge expr@(Let _ _) = do
+  let (binds, res) = flattenLets expr
+  binds' <- domerge binds
+  return $ mkNonRecLets binds' res
+  where
+    domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
+    domerge [] = return []
+    domerge (e:es) = do 
+      es' <- mapM (mergebinds e) es
+      es'' <- domerge es'
+      return (e:es'')
+
+    -- Uses the second bind to simplify the second bind, if applicable.
+    mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    mergebinds (b1, e1) (b2, e2)
+      -- Identical expressions? Replace the second binding with a reference to
+      -- the first binder.
+      | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1)
+      -- Different expressions? Don't change
+      | otherwise = return (b2, e2)
+-- Leave all other expressions unchanged
+letmerge expr = return expr
+letmergetop = everywhere ("letmerge", letmerge)
+-}
+
 --------------------------------
 -- Function inlining
 --------------------------------
@@ -193,6 +322,40 @@ letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
 inlinenonreptop :: Transform
 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
 
+inlinetoplevel, inlinetopleveltop :: Transform
+-- Any system name is candidate for inlining. Never inline user-defined
+-- functions, to preserver structure.
+inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
+  norm <- isNormalizeable f
+  -- See if this is a top level binding for which we have a body
+  body_maybe <- Trans.lift $ getGlobalBind f
+  if norm && Maybe.isJust body_maybe
+    then do
+      -- Get the normalized version
+      norm <- Trans.lift $ getNormalized f
+      if needsInline norm 
+        then do
+          -- Regenerate all uniques in the to-be-inlined expression
+          norm_uniqued <- Trans.lift $ genUniques norm
+          change norm_uniqued
+        else
+          return expr
+    else
+      -- No body or not normalizeable.
+      return expr
+-- Leave all other expressions unchanged
+inlinetoplevel expr = return expr
+inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
+
+needsInline :: CoreExpr -> Bool
+needsInline expr = case splitNormalized expr of
+  -- Inline any function that only has a single definition, it is probably
+  -- simple enough. This might inline some stuff that it shouldn't though it
+  -- will never inline user-defined functions (inlinetoplevel only tries
+  -- system names) and inlining should never break things.
+  (args, [bind], res) -> True
+  _ -> False
+
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
@@ -207,8 +370,8 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
-      id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
-      change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
+      id <- Trans.lift $ mkBinderFor scrut "scrut"
+      change $ Let (NonRec id scrut) (Case (Var id) b ty alts)
     else
       return expr
 -- Leave all other expressions unchanged
@@ -234,7 +397,7 @@ casesimpl expr@(Case scrut b ty alts) = do
   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
   let bindings = concat bindingss
   -- Replace the case with a let with bindings and a case
-  let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
+  let newlet = mkNonRecLets bindings (Case scrut b ty alts')
   -- If there are no non-wild binders, or this case is already a simple
   -- selector (i.e., a single alt with exactly one binding), already a simple
   -- selector altan no bindings (i.e., no wild binders in the original case),
@@ -257,7 +420,7 @@ casesimpl expr@(Case scrut b ty alts) = do
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
-    let bindings = Maybe.catMaybes (exprbinding_maybe : bindings_maybe)
+    let bindings = Maybe.catMaybes (bindings_maybe ++ [exprbinding_maybe])
     return (bindings, newalt)
     where
       -- Make wild alternatives for each binder
@@ -269,7 +432,7 @@ casesimpl expr@(Case scrut b ty alts) = do
       -- binding containing a case expression.
       dobndr :: CoreBndr -> Int -> TransformMonad (CoreBndr, Maybe (CoreBndr, CoreExpr))
       dobndr b i = do
-        repr <- isRepr (Var b)
+        repr <- isRepr b
         -- Is b wild (e.g., not a free var of expr. Since b is only in scope
         -- in expr, this means that b is unused if expr does not use it.)
         let wild = not (VarSet.elemVarSet b free_vars)
@@ -301,7 +464,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -341,8 +504,8 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
-      change $ Let (Rec [(id, arg)]) (App f (Var id))
+      id <- Trans.lift $ mkBinderFor arg "arg"
+      change $ Let (NonRec id arg) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
@@ -483,6 +646,25 @@ funextract expr = return expr
 -- Perform this transform everywhere
 funextracttop = everywhere ("funextract", funextract)
 
+--------------------------------
+-- Ensure that a function that just returns another function (or rather,
+-- another top-level binder) is still properly normalized. This is a temporary
+-- solution, we should probably integrate this pass with lambdasimpl and
+-- letsimpl instead.
+--------------------------------
+simplrestop expr@(Lam _ _) = return expr
+simplrestop expr@(Let _ _) = return expr
+simplrestop expr = do
+  local_var <- Trans.lift $ is_local_var expr
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr expr
+  if local_var || not repr
+    then
+      return expr
+    else do
+      id <- Trans.lift $ mkBinderFor expr "res" 
+      change $ Let (NonRec id expr) (Var id)
 --------------------------------
 -- End of transformations
 --------------------------------
@@ -491,7 +673,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
+transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -515,15 +697,12 @@ normalizeExpr ::
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
-      -- Introduce an empty Let at the top level, so there will always be
-      -- a let in the expression (none of the transformations will remove
-      -- the last let).
-      let expr' = Let (Rec []) expr
+      expr_uniqued <- genUniques expr
       -- Normalize this expression
-      trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
-      expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
-      return expr''
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr_uniqued ) ++ "\n") $ return ()
+      expr' <- dotransforms transforms expr_uniqued
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      return expr'
 
 -- | Get the value that is bound to the given binder at top level. Fails when
 --   there is no such binding.
@@ -535,3 +714,16 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+--   bindings and the result binder.
+splitNormalized ::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+  where
+    (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"