Add dictionary inlining transformation.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 4f02800547b729e7cc405c5d4990c366c284b4cb..4c04e8557553d4044944749b0b49a85c699e779a 100644 (file)
@@ -13,32 +13,25 @@ import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
-import qualified Data.Map as Map
+import qualified Data.Accessor.Monad.Trans.State as MonadState
 import qualified Data.Monoid as Monoid
 import qualified Data.Monoid as Monoid
-import Data.Accessor
+import qualified Data.Map as Map
 
 -- GHC API
 import CoreSyn
 
 -- GHC API
 import CoreSyn
-import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified CoreUtils
 import qualified Type
-import qualified TcType
-import qualified Name
 import qualified Id
 import qualified Var
 import qualified VarSet
 import qualified Id
 import qualified Var
 import qualified VarSet
-import qualified NameSet
 import qualified CoreFVs
 import qualified CoreFVs
-import qualified CoreUtils
 import qualified MkCore
 import qualified MkCore
-import qualified HscTypes
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
-import CLasH.VHDL.VHDLTypes
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.BinderTools
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.BinderTools
@@ -235,7 +228,7 @@ letflattop = everywhere ("letflat", letflat)
 --------------------------------
 -- Remove empty (recursive) lets
 letremove, letremovetop :: Transform
 --------------------------------
 -- Remove empty (recursive) lets
 letremove, letremovetop :: Transform
-letremove (Let (Rec []) res) = change res
+letremove (Let (Rec []) res) = change res
 -- Leave all other expressions unchanged
 letremove expr = return expr
 -- Perform this transform everywhere
 -- Leave all other expressions unchanged
 letremove expr = return expr
 -- Perform this transform everywhere
@@ -345,23 +338,19 @@ inlinetoplevel, inlinetopleveltop :: Transform
 -- Any system name is candidate for inlining. Never inline user-defined
 -- functions, to preserve structure.
 inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
 -- Any system name is candidate for inlining. Never inline user-defined
 -- functions, to preserve structure.
 inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
-  norm <- isNormalizeable f
-  -- See if this is a top level binding for which we have a body
-  body_maybe <- Trans.lift $ getGlobalBind f
-  if norm && Maybe.isJust body_maybe
-    then do
-      -- Get the normalized version
-      norm <- Trans.lift $ getNormalized f
-      if needsInline norm 
-        then do
-          -- Regenerate all uniques in the to-be-inlined expression
-          norm_uniqued <- Trans.lift $ genUniques norm
-          change norm_uniqued
-        else
-          return expr
-    else
+  norm_maybe <- Trans.lift $ getNormalized_maybe f
+  case norm_maybe of
       -- No body or not normalizeable.
       -- No body or not normalizeable.
-      return expr
+    Nothing -> return expr
+    Just norm -> if needsInline norm then do
+        -- Regenerate all uniques in the to-be-inlined expression
+        norm_uniqued <- Trans.lift $ genUniques norm
+        -- And replace the variable reference with the unique'd body.
+        change norm_uniqued
+      else
+        -- No need to inline
+        return expr
+
 -- Leave all other expressions unchanged
 inlinetoplevel expr = return expr
 inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
 -- Leave all other expressions unchanged
 inlinetoplevel expr = return expr
 inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
@@ -375,6 +364,22 @@ needsInline expr = case splitNormalized expr of
   (args, [bind], res) -> True
   _ -> False
 
   (args, [bind], res) -> True
   _ -> False
 
+
+--------------------------------
+-- Dictionary inlining
+--------------------------------
+-- Inline all top level dictionaries, so we can use them to resolve
+-- class methods based on the dictionary passed. 
+inlinedict expr@(Var f) | Id.isDictId f = do
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    Nothing -> return expr
+    Just body -> change body
+
+-- Leave all other expressions unchanged
+inlinedict expr = return expr
+inlinedicttop = everywhere ("inlinedict", inlinedict)
+
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
@@ -411,7 +416,7 @@ scrutbndrremove, scrutbndrremovetop :: Transform
 -- all occurences of the binder with the scrutinee variable.
 scrutbndrremove (Case (Var scrut) bndr ty alts) | bndr_used = do
     alts' <- mapM subs_bndr alts
 -- all occurences of the binder with the scrutinee variable.
 scrutbndrremove (Case (Var scrut) bndr ty alts) | bndr_used = do
     alts' <- mapM subs_bndr alts
-    return $ Case (Var scrut) wild ty alts'
+    change $ Case (Var scrut) wild ty alts'
   where
     is_used (_, _, expr) = expr_uses_binders [bndr] expr
     bndr_used = or $ map is_used alts
   where
     is_used (_, _, expr) = expr_uses_binders [bndr] expr
     bndr_used = or $ map is_used alts
@@ -419,6 +424,8 @@ scrutbndrremove (Case (Var scrut) bndr ty alts) | bndr_used = do
       expr' <- substitute bndr (Var scrut) expr
       return (con, bndrs, expr')
     wild = MkCore.mkWildBinder (Id.idType bndr)
       expr' <- substitute bndr (Var scrut) expr
       return (con, bndrs, expr')
     wild = MkCore.mkWildBinder (Id.idType bndr)
+-- Leave all other expressions unchanged
+scrutbndrremove expr = return expr
 scrutbndrremovetop = everywhere ("scrutbndrremove", scrutbndrremove)
 
 --------------------------------
 scrutbndrremovetop = everywhere ("scrutbndrremove", scrutbndrremove)
 
 --------------------------------
@@ -458,7 +465,7 @@ casesimpl expr@(Case scrut b ty alts) = do
     -- Extract a complex expression, if possible. For this we check if any of
     -- the new list of bndrs are used by expr. We can't use free_vars here,
     -- since that looks at the old bndrs.
     -- Extract a complex expression, if possible. For this we check if any of
     -- the new list of bndrs are used by expr. We can't use free_vars here,
     -- since that looks at the old bndrs.
-    let uses_bndrs = not $ VarSet.isEmptyVarSet $ CoreFVs.exprSomeFreeVars (`elem` newbndrs) expr
+    let uses_bndrs = not $ VarSet.isEmptyVarSet $ CoreFVs.exprSomeFreeVars (`elem` newbndrs) expr
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
     let newalt = (con, newbndrs, expr')
@@ -509,7 +516,7 @@ casesimpl expr@(Case scrut b ty alts) = do
             id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
             id <- Trans.lift $ mkBinderFor expr "caseval"
             -- We don't flag a change here, since casevalsimpl will do that above
             -- based on Just we return here.
-            return (Just (id, expr), Var id)
+            return (Just (id, expr), Var id)
           else
             -- Don't simplify anything else
             return (Nothing, expr)
           else
             -- Don't simplify anything else
             return (Nothing, expr)
@@ -583,6 +590,12 @@ argprop expr@(App _ _) | is_var fexpr = do
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
           newf <- Trans.lift $ mkFunction f newbody
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
           newf <- Trans.lift $ mkFunction f newbody
+
+          Trans.lift $ MonadState.modify tsInitStates (\ismap ->
+            let init_state_maybe = Map.lookup f ismap in
+            case init_state_maybe of
+              Nothing -> ismap
+              Just init_state -> Map.insert newf init_state ismap)
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
@@ -607,7 +620,7 @@ argprop expr@(App _ _) | is_var fexpr = do
     doarg arg = do
       repr <- isRepr arg
       bndrs <- Trans.lift getGlobalBinders
     doarg arg = do
       repr <- isRepr arg
       bndrs <- Trans.lift getGlobalBinders
-      let interesting var = Var.isLocalVar var && (not $ var `elem` bndrs)
+      let interesting var = Var.isLocalVar var && (var `notElem` bndrs)
       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
         then do
           -- Propagate all complex arguments that are not representable, but not
       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
         then do
           -- Propagate all complex arguments that are not representable, but not
@@ -623,10 +636,18 @@ argprop expr@(App _ _) | is_var fexpr = do
           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
           -- Mark the current expression as changed
           setChanged
           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
           -- Mark the current expression as changed
           setChanged
+          -- TODO: Clone the free_vars (and update references in arg), since
+          -- this might cause conflicts if two arguments that are propagated
+          -- share a free variable. Also, we are now introducing new variables
+          -- into a function that are not fresh, which violates the binder
+          -- uniqueness invariant.
           return (map Var free_vars, free_vars, arg)
         else do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
           return (map Var free_vars, free_vars, arg)
         else do
           -- Representable types will not be propagated, and arguments with free
           -- type variables will be propagated later.
+          -- Note that we implicitly remove any type variables in the type of
+          -- the original argument by using the type of the actual argument
+          -- for the new formal parameter.
           -- TODO: preserve original naming?
           id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
           -- TODO: preserve original naming?
           id <- Trans.lift $ mkBinderFor arg "param"
           -- Just pass the original argument to the new function, which binds it
@@ -715,22 +736,43 @@ simplrestop expr = do
 
 
 -- What transforms to run?
 
 
 -- What transforms to run?
-transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+transforms = [inlinedicttop, inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
 
--- | Returns the normalized version of the given function.
+-- | Returns the normalized version of the given function, or an error
+-- if it is not a known global binder.
 getNormalized ::
   CoreBndr -- ^ The function to get
   -> TranslatorSession CoreExpr -- The normalized function body
 getNormalized ::
   CoreBndr -- ^ The function to get
   -> TranslatorSession CoreExpr -- The normalized function body
-
-getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
-  if is_poly (Var bndr)
-    then
-      -- This should really only happen at the top level... TODO: Give
-      -- a different error if this happens down in the recursion.
-      error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
-    else do
-      expr <- getBinding bndr
-      normalizeExpr (show bndr) expr
+getNormalized bndr = do
+  norm <- getNormalized_maybe bndr
+  return $ Maybe.fromMaybe
+    (error $ "Normalize.getNormalized: Unknown or non-representable function requested: " ++ show bndr)
+    norm
+
+-- | Returns the normalized version of the given function, or Nothing
+-- when the binder is not a known global binder or is not normalizeable.
+getNormalized_maybe ::
+  CoreBndr -- ^ The function to get
+  -> TranslatorSession (Maybe CoreExpr) -- The normalized function body
+
+getNormalized_maybe bndr = do
+    expr_maybe <- getGlobalBind bndr
+    normalizeable <- isNormalizeable' bndr
+    if not normalizeable || Maybe.isNothing expr_maybe
+      then
+        -- Binder not normalizeable or not found
+        return Nothing
+      else if is_poly (Var bndr)
+        then
+          -- This should really only happen at the top level... TODO: Give
+          -- a different error if this happens down in the recursion.
+          error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+        else do
+          -- Binder found and is monomorphic. Normalize the expression
+          -- and cache the result.
+          normalized <- Utils.makeCached bndr tsNormalized $ 
+            normalizeExpr (show bndr) (Maybe.fromJust expr_maybe)
+          return (Just normalized)
 
 -- | Normalize an expression
 normalizeExpr ::
 
 -- | Normalize an expression
 normalizeExpr ::
@@ -746,17 +788,6 @@ normalizeExpr what expr = do
       trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
       return expr'
 
       trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
       return expr'
 
--- | Get the value that is bound to the given binder at top level. Fails when
---   there is no such binding.
-getBinding ::
-  CoreBndr -- ^ The binder to get the expression for
-  -> TranslatorSession CoreExpr -- ^ The value bound to the binder
-
-getBinding bndr = Utils.makeCached bndr tsBindings $ do
-  -- If the binding isn't in the "cache" (bindings map), then we can't create
-  -- it out of thin air, so return an error.
-  error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
-
 -- | Split a normalized expression into the argument binders, top level
 --   bindings and the result binder.
 splitNormalized ::
 -- | Split a normalized expression into the argument binders, top level
 --   bindings and the result binder.
 splitNormalized ::