Let scrutinee binder removal mark the expression as changed.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 40eab93c3787e27d9b599d939aeccd8e77264352..f0f2de2fccba3187092c4958bc1e3382cc0c0682 100644 (file)
@@ -13,32 +13,23 @@ import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
-import qualified Data.Map as Map
 import qualified Data.Monoid as Monoid
 import qualified Data.Monoid as Monoid
-import Data.Accessor
 
 -- GHC API
 import CoreSyn
 
 -- GHC API
 import CoreSyn
-import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified CoreUtils
 import qualified Type
-import qualified TcType
-import qualified Name
 import qualified Id
 import qualified Var
 import qualified VarSet
 import qualified Id
 import qualified Var
 import qualified VarSet
-import qualified NameSet
 import qualified CoreFVs
 import qualified CoreFVs
-import qualified CoreUtils
 import qualified MkCore
 import qualified MkCore
-import qualified HscTypes
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
-import CLasH.VHDL.VHDLTypes
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.BinderTools
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.BinderTools
@@ -235,7 +226,7 @@ letflattop = everywhere ("letflat", letflat)
 --------------------------------
 -- Remove empty (recursive) lets
 letremove, letremovetop :: Transform
 --------------------------------
 -- Remove empty (recursive) lets
 letremove, letremovetop :: Transform
-letremove (Let (Rec []) res) = change res
+letremove (Let (Rec []) res) = change res
 -- Leave all other expressions unchanged
 letremove expr = return expr
 -- Perform this transform everywhere
 -- Leave all other expressions unchanged
 letremove expr = return expr
 -- Perform this transform everywhere
@@ -398,6 +389,31 @@ scrutsimpl expr = return expr
 -- Perform this transform everywhere
 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
 
 -- Perform this transform everywhere
 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
 
+--------------------------------
+-- Scrutinee binder removal
+--------------------------------
+-- A case expression can have an extra binder, to which the scrutinee is bound
+-- after bringing it to WHNF. This is used for forcing evaluation of strict
+-- arguments. Since strictness does not matter for us (rather, everything is
+-- sort of strict), this binder is ignored when generating VHDL, and must thus
+-- be wild in the normal form.
+scrutbndrremove, scrutbndrremovetop :: Transform
+-- If the scrutinee is already simple, and the bndr is not wild yet, replace
+-- all occurences of the binder with the scrutinee variable.
+scrutbndrremove (Case (Var scrut) bndr ty alts) | bndr_used = do
+    alts' <- mapM subs_bndr alts
+    change $ Case (Var scrut) wild ty alts'
+  where
+    is_used (_, _, expr) = expr_uses_binders [bndr] expr
+    bndr_used = or $ map is_used alts
+    subs_bndr (con, bndrs, expr) = do
+      expr' <- substitute bndr (Var scrut) expr
+      return (con, bndrs, expr')
+    wild = MkCore.mkWildBinder (Id.idType bndr)
+-- Leave all other expressions unchanged
+scrutbndrremove expr = return expr
+scrutbndrremovetop = everywhere ("scrutbndrremove", scrutbndrremove)
+
 --------------------------------
 -- Case binder wildening
 --------------------------------
 --------------------------------
 -- Case binder wildening
 --------------------------------
@@ -435,7 +451,7 @@ casesimpl expr@(Case scrut b ty alts) = do
     -- Extract a complex expression, if possible. For this we check if any of
     -- the new list of bndrs are used by expr. We can't use free_vars here,
     -- since that looks at the old bndrs.
     -- Extract a complex expression, if possible. For this we check if any of
     -- the new list of bndrs are used by expr. We can't use free_vars here,
     -- since that looks at the old bndrs.
-    let uses_bndrs = not $ VarSet.isEmptyVarSet $ CoreFVs.exprSomeFreeVars (`elem` newbndrs) expr
+    let uses_bndrs = not $ VarSet.isEmptyVarSet $ CoreFVs.exprSomeFreeVars (`elem` newbndrs) expr
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
@@ -486,7 +502,7 @@ casesimpl expr@(Case scrut b ty alts) = do
             id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
-            return (Just (id, expr), Var id)
+            return (Just (id, expr), Var id)
           else
             -- Don't simplify anything else
             return (Nothing, expr)
           else
             -- Don't simplify anything else
             return (Nothing, expr)
@@ -584,7 +600,7 @@ argprop expr@(App _ _) | is_var fexpr = do
     doarg arg = do
       repr <- isRepr arg
       bndrs <- Trans.lift getGlobalBinders
     doarg arg = do
       repr <- isRepr arg
       bndrs <- Trans.lift getGlobalBinders
-      let interesting var = Var.isLocalVar var && (not $ var `elem` bndrs)
+      let interesting var = Var.isLocalVar var && (var `notElem` bndrs)
       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
         then do
           -- Propagate all complex arguments that are not representable, but not
       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
         then do
           -- Propagate all complex arguments that are not representable, but not
@@ -600,10 +616,18 @@ argprop expr@(App _ _) | is_var fexpr = do
           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
           -- Mark the current expression as changed
           setChanged
           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
           -- Mark the current expression as changed
           setChanged
+          -- TODO: Clone the free_vars (and update references in arg), since
+          -- this might cause conflicts if two arguments that are propagated
+          -- share a free variable. Also, we are now introducing new variables
+          -- into a function that are not fresh, which violates the binder
+          -- uniqueness invariant.
           return (map Var free_vars, free_vars, arg)
         else do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
           return (map Var free_vars, free_vars, arg)
         else do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
+          -- Note that we implicitly remove any type variables in the type of
+          -- the original argument by using the type of the actual argument
+          -- for the new formal parameter.
           -- TODO: preserve original naming?
           id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
           -- TODO: preserve original naming?
           id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
@@ -692,14 +716,14 @@ simplrestop expr = do
 
 
 -- What transforms to run?
 
 
 -- What transforms to run?
-transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
   CoreBndr -- ^ The function to get
   -> TranslatorSession CoreExpr -- The normalized function body
 
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
   CoreBndr -- ^ The function to get
   -> TranslatorSession CoreExpr -- The normalized function body
 
-getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
+getNormalized bndr = Utils.makeCached bndr tsNormalized $
   if is_poly (Var bndr)
     then
       -- This should really only happen at the top level... TODO: Give
   if is_poly (Var bndr)
     then
       -- This should really only happen at the top level... TODO: Give
@@ -729,7 +753,7 @@ getBinding ::
   CoreBndr -- ^ The binder to get the expression for
   -> TranslatorSession CoreExpr -- ^ The value bound to the binder
 
   CoreBndr -- ^ The binder to get the expression for
   -> TranslatorSession CoreExpr -- ^ The value bound to the binder
 
-getBinding bndr = Utils.makeCached bndr tsBindings $ do
+getBinding bndr = Utils.makeCached bndr tsBindings $
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
   -- If the binding isn't in the "cache" (bindings map), then we can't create
   -- it out of thin air, so return an error.
   error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr