Fix typo.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 8ec195b0ef936aadd89449571988da9e3c4f56e0..0d352761dbafec889da2345d1b82bc9dab0b4542 100644 (file)
@@ -4,11 +4,12 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (normalizeModule) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
@@ -22,6 +23,7 @@ import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified TcType
 import qualified CoreUtils
 import qualified Type
 import qualified TcType
+import qualified Name
 import qualified Id
 import qualified Var
 import qualified VarSet
 import qualified Id
 import qualified Var
 import qualified VarSet
@@ -34,9 +36,12 @@ import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
+import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
 import CLasH.VHDL.VHDLTypes
 import CLasH.Normalize.NormalizeTools
 import CLasH.VHDL.VHDLTypes
+import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.CoreTools
+import CLasH.Utils.Core.BinderTools
 import CLasH.Utils.Pretty
 
 --------------------------------
 import CLasH.Utils.Pretty
 
 --------------------------------
@@ -49,7 +54,7 @@ import CLasH.Utils.Pretty
 eta, etatop :: Transform
 eta expr | is_fun expr && not (is_lam expr) = do
   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
 eta, etatop :: Transform
 eta expr | is_fun expr && not (is_lam expr) = do
   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
-  id <- mkInternalVar "param" arg_ty
+  id <- Trans.lift $ mkInternalVar "param" arg_ty
   change (Lam id (App expr (Var id)))
 -- Leave all other expressions unchanged
 eta e = return e
   change (Lam id (App expr (Var id)))
 -- Leave all other expressions unchanged
 eta e = return e
@@ -59,8 +64,10 @@ etatop = notappargs ("eta", eta)
 -- β-reduction
 --------------------------------
 beta, betatop :: Transform
 -- β-reduction
 --------------------------------
 beta, betatop :: Transform
--- Substitute arg for x in expr
-beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
+-- Substitute arg for x in expr. For value lambda's, also clone before
+-- substitution.
+beta (App (Lam x expr) arg) | CoreSyn.isTyVar x = setChanged >> substitute x arg expr
+                            | otherwise      = setChanged >> substitute_clone x arg expr
 -- Propagate the application into the let
 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
 -- Propagate the application into each of the alternatives
 -- Propagate the application into the let
 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
 -- Propagate the application into each of the alternatives
@@ -88,31 +95,101 @@ castprop expr = return expr
 castproptop = everywhere ("castprop", castprop)
 
 --------------------------------
 castproptop = everywhere ("castprop", castprop)
 
 --------------------------------
--- let recursification
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
+--------------------------------
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+  -- Don't extract values that are already simpl
+  local_var <- Trans.lift $ is_local_var val
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr val
+  if (not local_var) && repr
+    then do
+      -- Generate a binder for the expression
+      id <- Trans.lift $ mkBinderFor val "castval"
+      -- Extract the expression
+      change $ Let (NonRec id val) (Cast (Var id) ty)
+    else
+      return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
+
+--------------------------------
+-- Lambda simplication
+--------------------------------
+-- Ensure that a lambda always evaluates to a let expressions or a simple
+-- variable reference.
+lambdasimpl, lambdasimpltop :: Transform
+-- Don't simplify a lambda that evaluates to let, since this is already
+-- normal form (and would cause infinite loops).
+lambdasimpl expr@(Lam _ (Let _ _)) = return expr
+-- Put the of a lambda in its own binding, but not when the expression is
+-- already a local variable, or not representable (to prevent loops with
+-- inlinenonrep).
+lambdasimpl expr@(Lam bndr res) = do
+  repr <- isRepr res
+  local_var <- Trans.lift $ is_local_var res
+  if not local_var && repr
+    then do
+      id <- Trans.lift $ mkBinderFor res "res"
+      change $ Lam bndr (Let (NonRec id res) (Var id))
+    else
+      -- If the result is already a local var or not representable, don't
+      -- extract it.
+      return expr
+
+-- Leave all other expressions unchanged
+lambdasimpl expr = return expr
+-- Perform this transform everywhere
+lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl)
+
+--------------------------------
+-- let derecursification
 --------------------------------
 --------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+  -- Nothing is liftable, just return
+  [] -> return expr
+  -- Something can be lifted, generate a new let expression
+  _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res)
+  where
+    -- Make a list of all the binders bound in this recursive let
+    bndrs = map fst binds
+    -- See which bindings are liftable
+    (liftable, nonliftable) = List.partition canlift binds
+    -- Any expression that does not use any of the binders in this recursive let
+    -- can be lifted into a nonrec let. It can't use its own binder either,
+    -- since that would mean the binding is self-recursive and should be in a
+    -- single bind recursive let.
+    canlift (bndr, e) = not $ expr_uses_binders bndrs e
 -- Leave all other expressions unchanged
 -- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
 -- Perform this transform everywhere
 -- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
 
 --------------------------------
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
 
 --------------------------------
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
+-- Don't simplify a let that evaluates to another let, since this is already
+-- normal form (and would cause infinite loops with letflat below).
+letsimpl expr@(Let _ (Let _ _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl expr@(Let (Rec binds) res) = do
+letsimpl expr@(Let binds res) = do
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
   repr <- isRepr res
   local_var <- Trans.lift $ is_local_var res
   if not local_var && repr
     then do
       -- If the result is not a local var already (to prevent loops with
       -- ourselves), extract it.
-      id <- mkInternalVar "foo" (CoreUtils.exprType res)
-      let bind = (id, res)
-      change $ Let (Rec (bind:binds)) (Var id)
+      id <- Trans.lift $ mkBinderFor res "foo"
+      change $ Let binds (Let (NonRec id  res) (Var id))
     else
       -- If the result is already a local var, don't extract it.
       return expr
     else
       -- If the result is already a local var, don't extract it.
       return expr
@@ -125,13 +202,18 @@ letsimpltop = everywhere ("letsimpl", letsimpl)
 --------------------------------
 -- let flattening
 --------------------------------
 --------------------------------
 -- let flattening
 --------------------------------
+-- Takes a let that binds another let, and turns that into two nested lets.
+-- e.g., from:
+-- let b = (let b' = expr' in res') in res
+-- to:
+-- let b' = expr' in (let b = res' in res)
 letflat, letflattop :: Transform
 letflat, letflattop :: Transform
+-- Turn a nonrec let that binds a let into two nested lets.
+letflat (Let (NonRec b (Let binds  res')) res) = 
+  change $ Let binds (Let (NonRec b res') res)
 letflat (Let (Rec binds) expr) = do
 letflat (Let (Rec binds) expr) = do
-  -- Turn each binding into a list of bindings (possibly containing just one
-  -- element, of course)
-  bindss <- Monad.mapM flatbind binds
-  -- Concat all the bindings
-  let binds' = concat bindss
+  -- Flatten each binding.
+  binds' <- Utils.concatM $ Monad.mapM flatbind binds
   -- Return the new let. We don't use change here, since possibly nothing has
   -- changed. If anything has changed, flatbind has already flagged that
   -- change.
   -- Return the new let. We don't use change here, since possibly nothing has
   -- changed. If anything has changed, flatbind has already flagged that
   -- change.
@@ -141,18 +223,86 @@ letflat (Let (Rec binds) expr) = do
     -- into a list with just that binding
     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
     -- into a list with just that binding
     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
+    flatbind (b, Let (NonRec b' expr') expr) = change [(b, expr), (b', expr')]
     flatbind (b, expr) = return [(b, expr)]
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
     flatbind (b, expr) = return [(b, expr)]
 -- Leave all other expressions unchanged
 letflat expr = return expr
 -- Perform this transform everywhere
 letflattop = everywhere ("letflat", letflat)
 
+--------------------------------
+-- empty let removal
+--------------------------------
+-- Remove empty (recursive) lets
+letremove, letremovetop :: Transform
+letremove (Let (Rec []) res) = change $ res
+-- Leave all other expressions unchanged
+letremove expr = return expr
+-- Perform this transform everywhere
+letremovetop = everywhere ("letremove", letremove)
+
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
 --------------------------------
 -- Simple let binding removal
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
-letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+letremovesimpletop :: Transform
+letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+
+--------------------------------
+-- Unused let binding removal
+--------------------------------
+letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (NonRec b bound) res) = do
+  let used = expr_uses_binders [b] res
+  if used
+    then return expr
+    else change res
+letremoveunused expr@(Let (Rec binds) res) = do
+  -- Filter out all unused binds.
+  let binds' = filter dobind binds
+  -- Only set the changed flag if binds got removed
+  changeif (length binds' /= length binds) (Let (Rec binds') res)
+    where
+      bound_exprs = map snd binds
+      -- For each bind check if the bind is used by res or any of the bound
+      -- expressions
+      dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs)
+-- Leave all other expressions unchanged
+letremoveunused expr = return expr
+letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
+
+{-
+--------------------------------
+-- Identical let binding merging
+--------------------------------
+-- Merge two bindings in a let if they are identical 
+-- TODO: We would very much like to use GHC's CSE module for this, but that
+-- doesn't track if something changed or not, so we can't use it properly.
+letmerge, letmergetop :: Transform
+letmerge expr@(Let _ _) = do
+  let (binds, res) = flattenLets expr
+  binds' <- domerge binds
+  return $ mkNonRecLets binds' res
+  where
+    domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
+    domerge [] = return []
+    domerge (e:es) = do 
+      es' <- mapM (mergebinds e) es
+      es'' <- domerge es'
+      return (e:es'')
+
+    -- Uses the second bind to simplify the second bind, if applicable.
+    mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    mergebinds (b1, e1) (b2, e2)
+      -- Identical expressions? Replace the second binding with a reference to
+      -- the first binder.
+      | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1)
+      -- Different expressions? Don't change
+      | otherwise = return (b2, e2)
+-- Leave all other expressions unchanged
+letmerge expr = return expr
+letmergetop = everywhere ("letmerge", letmerge)
+-}
 
 --------------------------------
 -- Function inlining
 
 --------------------------------
 -- Function inlining
@@ -172,6 +322,40 @@ letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_l
 inlinenonreptop :: Transform
 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
 
 inlinenonreptop :: Transform
 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
 
+inlinetoplevel, inlinetopleveltop :: Transform
+-- Any system name is candidate for inlining. Never inline user-defined
+-- functions, to preserve structure.
+inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
+  norm <- isNormalizeable f
+  -- See if this is a top level binding for which we have a body
+  body_maybe <- Trans.lift $ getGlobalBind f
+  if norm && Maybe.isJust body_maybe
+    then do
+      -- Get the normalized version
+      norm <- Trans.lift $ getNormalized f
+      if needsInline norm 
+        then do
+          -- Regenerate all uniques in the to-be-inlined expression
+          norm_uniqued <- Trans.lift $ genUniques norm
+          change norm_uniqued
+        else
+          return expr
+    else
+      -- No body or not normalizeable.
+      return expr
+-- Leave all other expressions unchanged
+inlinetoplevel expr = return expr
+inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
+
+needsInline :: CoreExpr -> Bool
+needsInline expr = case splitNormalized expr of
+  -- Inline any function that only has a single definition, it is probably
+  -- simple enough. This might inline some stuff that it shouldn't though it
+  -- will never inline user-defined functions (inlinetoplevel only tries
+  -- system names) and inlining should never break things.
+  (args, [bind], res) -> True
+  _ -> False
+
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
@@ -186,8 +370,8 @@ scrutsimpl expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
   if repr
     then do
   repr <- isRepr scrut
   if repr
     then do
-      id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
-      change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
+      id <- Trans.lift $ mkBinderFor scrut "scrut"
+      change $ Let (NonRec id scrut) (Case (Var id) b ty alts)
     else
       return expr
 -- Leave all other expressions unchanged
     else
       return expr
 -- Leave all other expressions unchanged
@@ -213,7 +397,7 @@ casesimpl expr@(Case scrut b ty alts) = do
   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
   let bindings = concat bindingss
   -- Replace the case with a let with bindings and a case
   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
   let bindings = concat bindingss
   -- Replace the case with a let with bindings and a case
-  let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
+  let newlet = mkNonRecLets bindings (Case scrut b ty alts')
   -- If there are no non-wild binders, or this case is already a simple
   -- selector (i.e., a single alt with exactly one binding), already a simple
   -- selector altan no bindings (i.e., no wild binders in the original case),
   -- If there are no non-wild binders, or this case is already a simple
   -- selector (i.e., a single alt with exactly one binding), already a simple
   -- selector altan no bindings (i.e., no wild binders in the original case),
@@ -236,7 +420,7 @@ casesimpl expr@(Case scrut b ty alts) = do
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
-    let bindings = Maybe.catMaybes (exprbinding_maybe : bindings_maybe)
+    let bindings = Maybe.catMaybes (bindings_maybe ++ [exprbinding_maybe])
     return (bindings, newalt)
     where
       -- Make wild alternatives for each binder
     return (bindings, newalt)
     where
       -- Make wild alternatives for each binder
@@ -248,7 +432,7 @@ casesimpl expr@(Case scrut b ty alts) = do
       -- binding containing a case expression.
       dobndr :: CoreBndr -> Int -> TransformMonad (CoreBndr, Maybe (CoreBndr, CoreExpr))
       dobndr b i = do
       -- binding containing a case expression.
       dobndr :: CoreBndr -> Int -> TransformMonad (CoreBndr, Maybe (CoreBndr, CoreExpr))
       dobndr b i = do
-        repr <- isRepr (Var b)
+        repr <- isRepr b
         -- Is b wild (e.g., not a free var of expr. Since b is only in scope
         -- in expr, this means that b is unused if expr does not use it.)
         let wild = not (VarSet.elemVarSet b free_vars)
         -- Is b wild (e.g., not a free var of expr. Since b is only in scope
         -- in expr, this means that b is unused if expr does not use it.)
         let wild = not (VarSet.elemVarSet b free_vars)
@@ -260,7 +444,7 @@ casesimpl expr@(Case scrut b ty alts) = do
             -- Create on new binder that will actually capture a value in this
             -- case statement, and return it.
             let bty = (Id.idType b)
             -- Create on new binder that will actually capture a value in this
             -- case statement, and return it.
             let bty = (Id.idType b)
-            id <- mkInternalVar "sel" bty
+            id <- Trans.lift $ mkInternalVar "sel" bty
             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
             let caseexpr = Case scrut b bty [(con, binders, Var id)]
             return (wildbndrs!!i, Just (b, caseexpr))
             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
             let caseexpr = Case scrut b bty [(con, binders, Var id)]
             return (wildbndrs!!i, Just (b, caseexpr))
@@ -280,7 +464,7 @@ casesimpl expr@(Case scrut b ty alts) = do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
         -- prevent loops with inlinenonrep).
         if (not uses_bndrs) && (not local_var) && repr
           then do
-            id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
+            id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             return $ (Just (id, expr), Var id)
@@ -320,8 +504,8 @@ appsimpl expr@(App f arg) = do
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
   local_var <- Trans.lift $ is_local_var arg
   if repr && not local_var
     then do -- Extract representable arguments
-      id <- mkInternalVar "arg" (CoreUtils.exprType arg)
-      change $ Let (Rec [(id, arg)]) (App f (Var id))
+      id <- Trans.lift $ mkBinderFor arg "arg"
+      change $ Let (NonRec id arg) (App f (Var id))
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
     else -- Leave non-representable arguments unchanged
       return expr
 -- Leave all other expressions unchanged
@@ -356,7 +540,7 @@ argprop expr@(App _ _) | is_var fexpr = do
           -- the old body applied to some arguments.
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
           -- the old body applied to some arguments.
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
-          newf <- mkFunction f newbody
+          newf <- Trans.lift $ mkFunction f newbody
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
@@ -402,7 +586,7 @@ argprop expr@(App _ _) | is_var fexpr = do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
           -- TODO: preserve original naming?
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
           -- TODO: preserve original naming?
-          id <- mkBinderFor arg "param"
+          id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
           -- to a new id and just pass that new id to the old function body.
           return ([arg], [id], mkReferenceTo id) 
           -- Just pass the original argument to the new function, which binds it
           -- to a new id and just pass that new id to the old function body.
           return ([arg], [id], mkReferenceTo id) 
@@ -449,7 +633,7 @@ funextract expr@(App _ _) | is_var fexpr = do
       -- by the argument expression.
       let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
       let body = MkCore.mkCoreLams free_vars arg
       -- by the argument expression.
       let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
       let body = MkCore.mkCoreLams free_vars arg
-      id <- mkBinderFor body "fun"
+      id <- Trans.lift $ mkBinderFor body "fun"
       Trans.lift $ addGlobalBind id body
       -- Replace the argument with a reference to the new function, applied to
       -- all vars it uses.
       Trans.lift $ addGlobalBind id body
       -- Replace the argument with a reference to the new function, applied to
       -- all vars it uses.
@@ -462,6 +646,25 @@ funextract expr = return expr
 -- Perform this transform everywhere
 funextracttop = everywhere ("funextract", funextract)
 
 -- Perform this transform everywhere
 funextracttop = everywhere ("funextract", funextract)
 
+--------------------------------
+-- Ensure that a function that just returns another function (or rather,
+-- another top-level binder) is still properly normalized. This is a temporary
+-- solution, we should probably integrate this pass with lambdasimpl and
+-- letsimpl instead.
+--------------------------------
+simplrestop expr@(Lam _ _) = return expr
+simplrestop expr@(Let _ _) = return expr
+simplrestop expr = do
+  local_var <- Trans.lift $ is_local_var expr
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr expr
+  if local_var || not repr
+    then
+      return expr
+    else do
+      id <- Trans.lift $ mkBinderFor expr "res" 
+      change $ Let (NonRec id expr) (Var id)
 --------------------------------
 -- End of transformations
 --------------------------------
 --------------------------------
 -- End of transformations
 --------------------------------
@@ -470,80 +673,57 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop]
-
--- Turns the given bind into VHDL
-normalizeModule ::
-  HscTypes.HscEnv
-  -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
-  -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
-  -> [CoreExpr]
-  -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
-  -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
-  -> ([(CoreBndr, CoreExpr)], [(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
-
-normalizeModule env uniqsupply bindings testexprs generate_for statefuls = runTransformSession env uniqsupply $ do
-  testbinds <- mapM (\x -> do { v <- mkBinderFor' x "test" ; return (v,x) } ) testexprs
-  let testbinders = (map fst testbinds)
-  -- Put all the bindings in this module in the tsBindings map
-  putA tsBindings (Map.fromList (bindings ++ testbinds))
-  -- (Recursively) normalize each of the requested bindings
-  mapM normalizeBind (generate_for ++ testbinders)
-  -- Get all initial bindings and the ones we produced
-  bindings_map <- getA tsBindings
-  let bindings = Map.assocs bindings_map
-  normalized_binders' <- getA tsNormalized
-  let normalized_binders = VarSet.delVarSetList normalized_binders' testbinders
-  let ret_testbinds = zip testbinders (Maybe.catMaybes $ map (\x -> lookup x bindings) testbinders)
-  let ret_binds = filter ((`VarSet.elemVarSet` normalized_binders) . fst) bindings
-  typestate <- getA tsType
-  -- But return only the normalized bindings
-  return $ (ret_binds, ret_testbinds, typestate)
-
-normalizeBind :: CoreBndr -> TransformSession ()
-normalizeBind bndr =
-  -- Don't normalize global variables, these should be either builtin
-  -- functions or data constructors.
-  Monad.when (Var.isLocalId bndr) $ do
-    -- Skip binders that have a polymorphic type, since it's impossible to
-    -- create polymorphic hardware.
-    if is_poly (Var bndr)
-      then
-        -- This should really only happen at the top level... TODO: Give
-        -- a different error if this happens down in the recursion.
-        error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
-      else do
-        normalized_funcs <- getA tsNormalized
-        -- See if this function was normalized already
-        if VarSet.elemVarSet bndr normalized_funcs
-          then
-            -- Yup, don't do it again
-            return ()
-          else do
-            -- Nope, note that it has been and do it.
-            modA tsNormalized (flip VarSet.extendVarSet bndr)
-            expr_maybe <- getGlobalBind bndr
-            case expr_maybe of 
-              Just expr -> do
-                -- Introduce an empty Let at the top level, so there will always be
-                -- a let in the expression (none of the transformations will remove
-                -- the last let).
-                let expr' = Let (Rec []) expr
-                -- Normalize this expression
-                trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
-                expr' <- dotransforms transforms expr'
-                trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
-                -- And store the normalized version in the session
-                modA tsBindings (Map.insert bndr expr')
-                -- Find all vars used with a function type. All of these should be global
-                -- binders (i.e., functions used), since any local binders with a function
-                -- type should have been inlined already.
-                bndrs <- getGlobalBinders
-                let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
-                let used_funcs = VarSet.varSetElems used_funcs_set
-                -- Process each of the used functions recursively
-                mapM normalizeBind used_funcs
-                return ()
-              -- We don't have a value for this binder. This really shouldn't
-              -- happen for local id's...
-              Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"
+transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+
+-- | Returns the normalized version of the given function.
+getNormalized ::
+  CoreBndr -- ^ The function to get
+  -> TranslatorSession CoreExpr -- The normalized function body
+
+getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
+  if is_poly (Var bndr)
+    then
+      -- This should really only happen at the top level... TODO: Give
+      -- a different error if this happens down in the recursion.
+      error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+    else do
+      expr <- getBinding bndr
+      normalizeExpr (show bndr) expr
+
+-- | Normalize an expression
+normalizeExpr ::
+  String -- ^ What are we normalizing? For debug output only.
+  -> CoreSyn.CoreExpr -- ^ The expression to normalize 
+  -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
+
+normalizeExpr what expr = do
+      expr_uniqued <- genUniques expr
+      -- Normalize this expression
+      trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr_uniqued ) ++ "\n") $ return ()
+      expr' <- dotransforms transforms expr_uniqued
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      return expr'
+
+-- | Get the value that is bound to the given binder at top level. Fails when
+--   there is no such binding.
+getBinding ::
+  CoreBndr -- ^ The binder to get the expression for
+  -> TranslatorSession CoreExpr -- ^ The value bound to the binder
+
+getBinding bndr = Utils.makeCached bndr tsBindings $ do
+  -- If the binding isn't in the "cache" (bindings map), then we can't create
+  -- it out of thin air, so return an error.
+  error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+--   bindings and the result binder.
+splitNormalized ::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+  where
+    (args, letexpr) = CoreSyn.collectBinders expr
+    (binds, resexpr) = flattenLets letexpr
+    res = case resexpr of 
+      (Var x) -> x
+      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"