Moved to new GHC API (6.11). Also use vhdl package for the VHDL AST
[matthijs/master-project/cλash.git] / Normalize.hs
index 647168b4587df5f782fca4aef6e8ce998fcae508..12356e23c7b9af33a6a0c8989fa31efc27172ab6 100644 (file)
@@ -11,7 +11,9 @@ import Debug.Trace
 import qualified Maybe
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
+import qualified Control.Monad.Trans.Writer as Writer
 import qualified Data.Map as Map
+import qualified Data.Monoid as Monoid
 import Data.Accessor
 
 -- GHC API
@@ -19,16 +21,23 @@ import CoreSyn
 import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
+import qualified TcType
 import qualified Id
 import qualified Var
 import qualified VarSet
+import qualified NameSet
 import qualified CoreFVs
+import qualified CoreUtils
+import qualified MkCore
+import qualified HscTypes
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import NormalizeTypes
 import NormalizeTools
+import VHDLTypes
 import CoreTools
+import Pretty
 
 --------------------------------
 -- Start of transformations
@@ -44,7 +53,7 @@ eta expr | is_fun expr && not (is_lam expr) = do
   change (Lam id (App expr (Var id)))
 -- Leave all other expressions unchanged
 eta e = return e
-etatop = notapplied ("eta", eta)
+etatop = notappargs ("eta", eta)
 
 --------------------------------
 -- β-reduction
@@ -58,12 +67,26 @@ beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
 beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts'
   where 
     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts
-    (_, ty') = Type.splitFunTy ty
+    ty' = CoreUtils.applyTypeToArg ty arg
 -- Leave all other expressions unchanged
 beta expr = return expr
 -- Perform this transform everywhere
 betatop = everywhere ("beta", beta)
 
+--------------------------------
+-- Cast propagation
+--------------------------------
+-- Try to move casts as much downward as possible.
+castprop, castproptop :: Transform
+castprop (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty)
+castprop expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts')
+  where
+    alts' = map (\(con, bndrs, expr) -> (con, bndrs, (Cast expr ty))) alts
+-- Leave all other expressions unchanged
+castprop expr = return expr
+-- Perform this transform everywhere
+castproptop = everywhere ("castprop", castprop)
+
 --------------------------------
 -- let recursification
 --------------------------------
@@ -78,14 +101,21 @@ letrectop = everywhere ("letrec", letrec)
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
--- Don't simplifiy lets that are already simple
-letsimpl expr@(Let _ (Var _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is applicable (to prevent loops with inlinefun).
-letsimpl (Let (Rec binds) expr) | not $ is_applicable expr = do
-  id <- mkInternalVar "foo" (CoreUtils.exprType expr)
-  let bind = (id, expr)
-  change $ Let (Rec (bind:binds)) (Var id)
+letsimpl expr@(Let (Rec binds) res) | not $ is_applicable expr = do
+  local_var <- Trans.lift $ is_local_var res
+  if not local_var
+    then do
+      -- If the result is not a local var already (to prevent loops with
+      -- ourselves), extract it.
+      id <- mkInternalVar "foo" (CoreUtils.exprType res)
+      let bind = (id, res)
+      change $ Let (Rec (bind:binds)) (Var id)
+    else
+      -- If the result is already a local var, don't extract it.
+      return expr
+
 -- Leave all other expressions unchanged
 letsimpl expr = return expr
 -- Perform this transform everywhere
@@ -121,7 +151,7 @@ letflattop = everywhere ("letflat", letflat)
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
 letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) -> True; otherwise -> False))
+letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Function inlining
@@ -138,8 +168,8 @@ letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v)
 -- will just not work on those function-typed values at first, but the other
 -- transformations (in particular β-reduction) should make sure that the type
 -- of those values eventually becomes primitive.
-inlinefuntop :: Transform
-inlinefuntop = everywhere ("inlinefun", inlinebind (is_applicable . snd))
+inlinenonreptop :: Transform
+inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
 
 --------------------------------
 -- Scrutinee simplification
@@ -175,7 +205,7 @@ casewild expr@(Case scrut b ty alts) = do
   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
   where
   -- Generate a single wild binder, since they are all the same
-  wild = Id.mkWildId
+  wild = MkCore.mkWildBinder
   -- Wilden the binders of one alt, producing a list of bindings as a
   -- sideeffect.
   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
@@ -189,14 +219,17 @@ casewild expr@(Case scrut b ty alts) = do
     return (bindings, newalt)
     where
       -- Make all binders wild
-      wildbndrs = map (\bndr -> Id.mkWildId (Id.idType bndr)) bndrs
+      wildbndrs = map (\bndr -> MkCore.mkWildBinder (Id.idType bndr)) bndrs
+      -- A set of all the binders that are used by the expression
+      free_vars = CoreFVs.exprSomeFreeVars (`elem` bndrs) expr
       -- Creates a case statement to retrieve the ith element from the scrutinee
       -- and binds that to b.
       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
       mkextracts b i =
-        if is_wild b || Type.isFunTy (Id.idType b) 
-          -- Don't create extra bindings for binders that are already wild, or
-          -- for binders that bind function types (to prevent loops with
+        if not (VarSet.elemVarSet b free_vars) || Type.isFunTy (Id.idType b) 
+          -- Don't create extra bindings for binders that are already wild
+          -- (e.g. not in the free variables of expr, so unused), or for
+          -- binders that bind function types (to prevent loops with
           -- inlinefun).
           then return Nothing
           else do
@@ -265,57 +298,159 @@ caseremove expr = return expr
 caseremovetop = everywhere ("caseremove", caseremove)
 
 --------------------------------
--- Application simplification
+-- Argument extraction
 --------------------------------
--- Make sure that all arguments in an application are simple variables.
+-- Make sure that all arguments of a representable type are simple variables.
 appsimpl, appsimpltop :: Transform
--- Don't simplify arguments that are already simple
-appsimpl expr@(App f (Var _)) = return expr
--- Simplify all non-applicable (to prevent loops with inlinefun) arguments,
--- except for type arguments (since a let can't bind type vars, only a lambda
--- can). Do this by introducing a new Let that binds the argument and passing
--- the new binder in the application.
-appsimpl (App f expr) | (not $ is_applicable expr) && (not $ CoreSyn.isTypeArg expr) = do
-  id <- mkInternalVar "arg" (CoreUtils.exprType expr)
-  change $ Let (Rec [(id, expr)]) (App f (Var id))
+-- Simplify all representable arguments. Do this by introducing a new Let
+-- that binds the argument and passing the new binder in the application.
+appsimpl expr@(App f arg) = do
+  -- Check runtime representability
+  repr <- isRepr arg
+  local_var <- Trans.lift $ is_local_var arg
+  if repr && not local_var
+    then do -- Extract representable arguments
+      id <- mkInternalVar "arg" (CoreUtils.exprType arg)
+      change $ Let (Rec [(id, arg)]) (App f (Var id))
+    else -- Leave non-representable arguments unchanged
+      return expr
 -- Leave all other expressions unchanged
 appsimpl expr = return expr
 -- Perform this transform everywhere
 appsimpltop = everywhere ("appsimpl", appsimpl)
 
-
 --------------------------------
--- Type argument propagation
---------------------------------
--- Remove all applications to type arguments, by duplicating the function
--- called with the type application in its new definition. We leave
--- dictionaries that might be associated with the type untouched, the funprop
--- transform should propagate these later on.
-typeprop, typeproptop :: Transform
--- Transform any function that is applied to a type argument. Since type
--- arguments are always the first ones to apply and we'll remove all type
--- arguments, we can simply do them one by one. We only propagate type
--- arguments without any free tyvars, since tyvars those wouldn't be in scope
--- in the new function.
-typeprop expr@(App (Var f) arg@(Type ty)) | not $ has_free_tyvars arg = do
-  id <- cloneVar f
-  let newty = Type.applyTy (Id.idType f) ty
-  let newf = Var.setVarType id newty
+-- Function-typed argument propagation
+--------------------------------
+-- Remove all applications to function-typed arguments, by duplication the
+-- function called with the function-typed parameter replaced by the free
+-- variables of the argument passed in.
+argprop, argproptop :: Transform
+-- Transform any application of a named function (i.e., skip applications of
+-- lambda's). Also skip applications that have arguments with free type
+-- variables, since we can't inline those.
+argprop expr@(App _ _) | is_var fexpr = do
+  -- Find the body of the function called
   body_maybe <- Trans.lift $ getGlobalBind f
   case body_maybe of
     Just body -> do
-      let newbody = App body (Type ty)
-      Trans.lift $ addGlobalBind newf newbody
-      change (Var newf)
+      -- Process each of the arguments in turn
+      (args', changed) <- Writer.listen $ mapM doarg args
+      -- See if any of the arguments changed
+      case Monoid.getAny changed of
+        True -> do
+          let (newargs', newparams', oldargs) = unzip3 args'
+          let newargs = concat newargs'
+          let newparams = concat newparams'
+          -- Create a new body that consists of a lambda for all new arguments and
+          -- the old body applied to some arguments.
+          let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
+          -- Create a new function with the same name but a new body
+          newf <- mkFunction f newbody
+          -- Replace the original application with one of the new function to the
+          -- new arguments.
+          change $ MkCore.mkCoreApps (Var newf) newargs
+        False ->
+          -- Don't change the expression if none of the arguments changed
+          return expr
+      
     -- If we don't have a body for the function called, leave it unchanged (it
     -- should be a primitive function then).
     Nothing -> return expr
+  where
+    -- Find the function called and the arguments
+    (fexpr, args) = collectArgs expr
+    Var f = fexpr
+
+    -- Process a single argument and return (args, bndrs, arg), where args are
+    -- the arguments to replace the given argument in the original
+    -- application, bndrs are the binders to include in the top-level lambda
+    -- in the new function body, and arg is the argument to apply to the old
+    -- function body.
+    doarg :: CoreExpr -> TransformMonad ([CoreExpr], [CoreBndr], CoreExpr)
+    doarg arg = do
+      repr <- isRepr arg
+      bndrs <- Trans.lift getGlobalBinders
+      let interesting var = Var.isLocalVar var && (not $ var `elem` bndrs)
+      if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
+        then do
+          -- Propagate all complex arguments that are not representable, but not
+          -- arguments with free type variables (since those would require types
+          -- not known yet, which will always be known eventually).
+          -- Find interesting free variables, each of which should be passed to
+          -- the new function instead of the original function argument.
+          -- 
+          -- Interesting vars are those that are local, but not available from the
+          -- top level scope (functions from this module are defined as local, but
+          -- they're not local to this function, so we can freely move references
+          -- to them into another function).
+          let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
+          -- Mark the current expression as changed
+          setChanged
+          return (map Var free_vars, free_vars, arg)
+        else do
+          -- Representable types will not be propagated, and arguments with free
+          -- type variables will be propagated later.
+          -- TODO: preserve original naming?
+          id <- mkBinderFor arg "param"
+          -- Just pass the original argument to the new function, which binds it
+          -- to a new id and just pass that new id to the old function body.
+          return ([arg], [id], mkReferenceTo id) 
 -- Leave all other expressions unchanged
-typeprop expr = return expr
+argprop expr = return expr
 -- Perform this transform everywhere
-typeproptop = everywhere ("typeprop", typeprop)
+argproptop = everywhere ("argprop", argprop)
+
+--------------------------------
+-- Function-typed argument extraction
+--------------------------------
+-- This transform takes any function-typed argument that cannot be propagated
+-- (because the function that is applied to it is a builtin function), and
+-- puts it in a brand new top level binder. This allows us to for example
+-- apply map to a lambda expression This will not conflict with inlinefun,
+-- since that only inlines local let bindings, not top level bindings.
+funextract, funextracttop :: Transform
+funextract expr@(App _ _) | is_var fexpr = do
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    -- We don't have a function body for f, so we can perform this transform.
+    Nothing -> do
+      -- Find the new arguments
+      args' <- mapM doarg args
+      -- And update the arguments. We use return instead of changed, so the
+      -- changed flag doesn't get set if none of the args got changed.
+      return $ MkCore.mkCoreApps fexpr args'
+    -- We have a function body for f, leave this application to funprop
+    Just _ -> return expr
+  where
+    -- Find the function called and the arguments
+    (fexpr, args) = collectArgs expr
+    Var f = fexpr
+    -- Change any arguments that have a function type, but are not simple yet
+    -- (ie, a variable or application). This means to create a new function
+    -- for map (\f -> ...) b, but not for map (foo a) b.
+    --
+    -- We could use is_applicable here instead of is_fun, but I think
+    -- arguments to functions could only have forall typing when existential
+    -- typing is enabled. Not sure, though.
+    doarg arg | not (is_simple arg) && is_fun arg = do
+      -- Create a new top level binding that binds the argument. Its body will
+      -- be extended with lambda expressions, to take any free variables used
+      -- by the argument expression.
+      let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
+      let body = MkCore.mkCoreLams free_vars arg
+      id <- mkBinderFor body "fun"
+      Trans.lift $ addGlobalBind id body
+      -- Replace the argument with a reference to the new function, applied to
+      -- all vars it uses.
+      change $ MkCore.mkCoreApps (Var id) (map Var free_vars)
+    -- Leave all other arguments untouched
+    doarg arg = return arg
 
--- TODO: introduce top level let if needed?
+-- Leave all other expressions unchanged
+funextract expr = return expr
+-- Perform this transform everywhere
+funextracttop = everywhere ("funextract", funextract)
 
 --------------------------------
 -- End of transformations
@@ -325,17 +460,18 @@ typeproptop = everywhere ("typeprop", typeprop)
 
 
 -- What transforms to run?
-transforms = [typeproptop, etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinenonreptop, appsimpltop]
 
 -- Turns the given bind into VHDL
-normalizeModule :: 
-  UniqSupply.UniqSupply -- ^ A UniqSupply we can use
+normalizeModule ::
+  HscTypes.HscEnv
+  -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
-  -> [(CoreBndr, CoreExpr)] -- ^ The resulting VHDL
+  -> ([(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
 
-normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession uniqsupply $ do
+normalizeModule env uniqsupply bindings generate_for statefuls = runTransformSession env uniqsupply $ do
   -- Put all the bindings in this module in the tsBindings map
   putA tsBindings (Map.fromList bindings)
   -- (Recursively) normalize each of the requested bindings
@@ -344,38 +480,54 @@ normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession
   bindings_map <- getA tsBindings
   let bindings = Map.assocs bindings_map
   normalized_bindings <- getA tsNormalized
+  typestate <- getA tsType
   -- But return only the normalized bindings
-  return $ filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings
+  return $ (filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings, typestate)
 
 normalizeBind :: CoreBndr -> TransformSession ()
-normalizeBind bndr = do
-  normalized_funcs <- getA tsNormalized
-  -- See if this function was normalized already
-  if VarSet.elemVarSet bndr normalized_funcs
-    then
-      -- Yup, don't do it again
-      return ()
-    else do
-      -- Nope, note that it has been and do it.
-      modA tsNormalized (flip VarSet.extendVarSet bndr)
-      expr_maybe <- getGlobalBind bndr
-      case expr_maybe of 
-        Just expr -> do
-          -- Normalize this expression
-          trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr ) ++ "\n") $ return ()
-          expr' <- dotransforms transforms expr
-          trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
-          -- And store the normalized version in the session
-          modA tsBindings (Map.insert bndr expr')
-          -- Find all vars used with a function type. All of these should be global
-          -- binders (i.e., functions used), since any local binders with a function
-          -- type should have been inlined already.
-          let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> (Type.isFunTy . snd . Type.splitForAllTys . Id.idType) v) expr'
-          let used_funcs = VarSet.varSetElems used_funcs_set
-          -- Process each of the used functions recursively
-          mapM normalizeBind used_funcs
-          return ()
-        -- We don't have a value for this binder, let's assume this is a builtin
-        -- function. This might need some extra checking and a nice error
-        -- message).
-        Nothing -> return ()
+normalizeBind bndr =
+  -- Don't normalize global variables, these should be either builtin
+  -- functions or data constructors.
+  Monad.when (Var.isLocalId bndr) $ do
+    -- Skip binders that have a polymorphic type, since it's impossible to
+    -- create polymorphic hardware.
+    if is_poly (Var bndr)
+      then
+        -- This should really only happen at the top level... TODO: Give
+        -- a different error if this happens down in the recursion.
+        error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+      else do
+        normalized_funcs <- getA tsNormalized
+        -- See if this function was normalized already
+        if VarSet.elemVarSet bndr normalized_funcs
+          then
+            -- Yup, don't do it again
+            return ()
+          else do
+            -- Nope, note that it has been and do it.
+            modA tsNormalized (flip VarSet.extendVarSet bndr)
+            expr_maybe <- getGlobalBind bndr
+            case expr_maybe of 
+              Just expr -> do
+                -- Introduce an empty Let at the top level, so there will always be
+                -- a let in the expression (none of the transformations will remove
+                -- the last let).
+                let expr' = Let (Rec []) expr
+                -- Normalize this expression
+                trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+                expr' <- dotransforms transforms expr'
+                trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+                -- And store the normalized version in the session
+                modA tsBindings (Map.insert bndr expr')
+                -- Find all vars used with a function type. All of these should be global
+                -- binders (i.e., functions used), since any local binders with a function
+                -- type should have been inlined already.
+                bndrs <- getGlobalBinders
+                let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
+                let used_funcs = VarSet.varSetElems used_funcs_set
+                -- Process each of the used functions recursively
+                mapM normalizeBind used_funcs
+                return ()
+              -- We don't have a value for this binder. This really shouldn't
+              -- happen for local id's...
+              Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"