Don't propagate types with free tyvars.
[matthijs/master-project/cλash.git] / Normalize.hs
index 2b37e09a72719113798cb8cfef66fad7fca6f698..0d58d04a9d53c323494a046914047b0178b10eee 100644 (file)
@@ -1,14 +1,18 @@
+{-# LANGUAGE PackageImports #-}
 --
 -- Functions to bring a Core expression in normal form. This module provides a
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module Normalize (normalize) where
+module Normalize (normalizeModule) where
 
 -- Standard modules
 import Debug.Trace
 import qualified Maybe
+import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
+import qualified Data.Map as Map
+import Data.Accessor
 
 -- GHC API
 import CoreSyn
@@ -16,7 +20,8 @@ import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
 import qualified Id
-import qualified UniqSet
+import qualified Var
+import qualified VarSet
 import qualified CoreFVs
 import Outputable ( showSDoc, ppr, nest )
 
@@ -76,8 +81,8 @@ letsimpl, letsimpltop :: Transform
 -- Don't simplifiy lets that are already simple
 letsimpl expr@(Let _ (Var _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
--- expression has a function type (to prevent loops with inlinefun).
-letsimpl (Let (Rec binds) expr) | not $ is_fun expr = do
+-- expression is applicable (to prevent loops with inlinefun).
+letsimpl (Let (Rec binds) expr) | not $ is_applicable expr = do
   id <- mkInternalVar "foo" (CoreUtils.exprType expr)
   let bind = (id, expr)
   change $ Let (Rec (bind:binds)) (Var id)
@@ -121,7 +126,12 @@ letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v)
 --------------------------------
 -- Function inlining
 --------------------------------
--- Remove a = B bindings, with B :: a -> b, from let expressions everywhere.
+-- Remove a = B bindings, with B :: a -> b, or B :: forall x . T, from let
+-- expressions everywhere. This means that any value that still needs to be
+-- applied to something else (polymorphic values need to be applied to a
+-- Type) will be inlined, and will eventually be applied to all their
+-- arguments.
+--
 -- This is a tricky function, which is prone to create loops in the
 -- transformations. To fix this, we make sure that no transformation will
 -- create a new let binding with a function type. These other transformations
@@ -129,7 +139,7 @@ letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v)
 -- transformations (in particular β-reduction) should make sure that the type
 -- of those values eventually becomes primitive.
 inlinefuntop :: Transform
-inlinefuntop = everywhere ("inlinefun", inlinebind (Type.isFunTy . CoreUtils.exprType . snd))
+inlinefuntop = everywhere ("inlinefun", inlinebind (is_applicable . snd))
 
 --------------------------------
 -- Scrutinee simplification
@@ -138,10 +148,10 @@ scrutsimpl,scrutsimpltop :: Transform
 -- Don't touch scrutinees that are already simple
 scrutsimpl expr@(Case (Var _) _ _ _) = return expr
 -- Replace all other cases with a let that binds the scrutinee and a new
--- simple scrutinee, but not when the scrutinee is a function type (to prevent
--- loops with inlinefun, though I don't think a scrutinee can have a function
--- type...)
-scrutsimpl (Case scrut b ty alts) | not $ is_fun scrut = do
+-- simple scrutinee, but not when the scrutinee is applicable (to prevent
+-- loops with inlinefun, though I don't think a scrutinee can be
+-- applicable...)
+scrutsimpl (Case scrut b ty alts) | not $ is_applicable scrut = do
   id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
   change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
 -- Leave all other expressions unchanged
@@ -223,15 +233,15 @@ casevalsimpl expr@(Case scrut b ty alts) = do
     -- replacing the case value with that id. Only do this when the case value
     -- does not use any of the binders bound by this alternative, for that would
     -- cause those binders to become unbound when moving the value outside of
-    -- the case statement. Also, don't create a binding for function-typed
+    -- the case statement. Also, don't create a binding for applicable
     -- expressions, to prevent loops with inlinefun.
-    doalt (con, bndrs, expr) | (not usesvars) && (not $ is_fun expr) = do
+    doalt (con, bndrs, expr) | (not usesvars) && (not $ is_applicable expr) = do
       id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
       -- We don't flag a change here, since casevalsimpl will do that above
       -- based on Just we return here.
       return $ (Just (id, expr), (con, bndrs, Var id))
       -- Find if any of the binders are used by expr
-      where usesvars = (not . UniqSet.isEmptyUniqSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
+      where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
     -- Don't simplify anything else
     doalt alt = return (Nothing, alt)
 -- Leave all other expressions unchanged
@@ -248,7 +258,7 @@ caseremove, caseremovetop :: Transform
 -- Replace a useless case by the value of its single alternative
 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
     -- Find if any of the binders are used by expr
-    where usesvars = (not . UniqSet.isEmptyUniqSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
+    where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
 -- Leave all other expressions unchanged
 caseremove expr = return expr
 -- Perform this transform everywhere
@@ -261,10 +271,11 @@ caseremovetop = everywhere ("caseremove", caseremove)
 appsimpl, appsimpltop :: Transform
 -- Don't simplify arguments that are already simple
 appsimpl expr@(App f (Var _)) = return expr
--- Simplify all arguments that do not have a function type (to prevent loops
--- with inlinefun) and is not a type argument. Do this by introducing a new
--- Let that binds the argument and passing the new binder in the application.
-appsimpl (App f expr) | (not $ is_fun expr) && (not $ CoreSyn.isTypeArg expr) = do
+-- Simplify all non-applicable (to prevent loops with inlinefun) arguments,
+-- except for type arguments (since a let can't bind type vars, only a lambda
+-- can). Do this by introducing a new Let that binds the argument and passing
+-- the new binder in the application.
+appsimpl (App f expr) | (not $ is_applicable expr) && (not $ CoreSyn.isTypeArg expr) = do
   id <- mkInternalVar "arg" (CoreUtils.exprType expr)
   change $ Let (Rec [(id, expr)]) (App f (Var id))
 -- Leave all other expressions unchanged
@@ -272,6 +283,38 @@ appsimpl expr = return expr
 -- Perform this transform everywhere
 appsimpltop = everywhere ("appsimpl", appsimpl)
 
+
+--------------------------------
+-- Type argument propagation
+--------------------------------
+-- Remove all applications to type arguments, by duplicating the function
+-- called with the type application in its new definition. We leave
+-- dictionaries that might be associated with the type untouched, the funprop
+-- transform should propagate these later on.
+typeprop, typeproptop :: Transform
+-- Transform any function that is applied to a type argument. Since type
+-- arguments are always the first ones to apply and we'll remove all type
+-- arguments, we can simply do them one by one. We only propagate type
+-- arguments without any free tyvars, since tyvars those wouldn't be in scope
+-- in the new function.
+typeprop expr@(App (Var f) arg@(Type ty)) | not $ has_free_tyvars arg = do
+  id <- cloneVar f
+  let newty = Type.applyTy (Id.idType f) ty
+  let newf = Var.setVarType id newty
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    Just body -> do
+      let newbody = App body (Type ty)
+      Trans.lift $ addGlobalBind newf newbody
+      change (Var newf)
+    -- If we don't have a body for the function called, leave it unchanged (it
+    -- should be a primitive function then).
+    Nothing -> return expr
+-- Leave all other expressions unchanged
+typeprop expr = return expr
+-- Perform this transform everywhere
+typeproptop = everywhere ("typeprop", typeprop)
+
 -- TODO: introduce top level let if needed?
 
 --------------------------------
@@ -282,10 +325,56 @@ appsimpltop = everywhere ("appsimpl", appsimpl)
 
 
 -- What transforms to run?
-transforms = [etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
+transforms = [typeproptop, etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
+
+-- Turns the given bind into VHDL
+normalizeModule :: 
+  UniqSupply.UniqSupply -- ^ A UniqSupply we can use
+  -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
+  -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
+  -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
+  -> [(CoreBndr, CoreExpr)] -- ^ The resulting VHDL
 
--- Normalize a core expression by running transforms until none applies
--- anymore. Uses a UniqSupply to generate new identifiers.
-normalize :: UniqSupply.UniqSupply -> CoreExpr -> CoreExpr
-normalize = dotransforms transforms
+normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession uniqsupply $ do
+  -- Put all the bindings in this module in the tsBindings map
+  putA tsBindings (Map.fromList bindings)
+  -- (Recursively) normalize each of the requested bindings
+  mapM normalizeBind generate_for
+  -- Get all initial bindings and the ones we produced
+  bindings_map <- getA tsBindings
+  let bindings = Map.assocs bindings_map
+  normalized_bindings <- getA tsNormalized
+  -- But return only the normalized bindings
+  return $ filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings
 
+normalizeBind :: CoreBndr -> TransformSession ()
+normalizeBind bndr = do
+  normalized_funcs <- getA tsNormalized
+  -- See if this function was normalized already
+  if VarSet.elemVarSet bndr normalized_funcs
+    then
+      -- Yup, don't do it again
+      return ()
+    else do
+      -- Nope, note that it has been and do it.
+      modA tsNormalized (flip VarSet.extendVarSet bndr)
+      expr_maybe <- getGlobalBind bndr
+      case expr_maybe of 
+        Just expr -> do
+          -- Normalize this expression
+          expr' <- dotransforms transforms expr
+          let expr'' = trace ("Before:\n\n" ++ showSDoc ( ppr expr ) ++ "\n\nAfter:\n\n" ++ showSDoc ( ppr expr')) expr'
+          -- And store the normalized version in the session
+          modA tsBindings (Map.insert bndr expr'')
+          -- Find all vars used with a function type. All of these should be global
+          -- binders (i.e., functions used), since any local binders with a function
+          -- type should have been inlined already.
+          let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> (Type.isFunTy . snd . Type.splitForAllTys . Id.idType) v) expr''
+          let used_funcs = VarSet.varSetElems used_funcs_set
+          -- Process each of the used functions recursively
+          mapM normalizeBind used_funcs
+          return ()
+        -- We don't have a value for this binder, let's assume this is a builtin
+        -- function. This might need some extra checking and a nice error
+        -- message).
+        Nothing -> return ()