Moved to new GHC API (6.11). Also use vhdl package for the VHDL AST
[matthijs/master-project/cλash.git] / Normalize.hs
index 2ecf2fa102ba5b56cdc5b9457ac8c792666c06dd..12356e23c7b9af33a6a0c8989fa31efc27172ab6 100644 (file)
@@ -21,17 +21,21 @@ import CoreSyn
 import qualified UniqSupply
 import qualified CoreUtils
 import qualified Type
+import qualified TcType
 import qualified Id
 import qualified Var
 import qualified VarSet
+import qualified NameSet
 import qualified CoreFVs
 import qualified CoreUtils
 import qualified MkCore
+import qualified HscTypes
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import NormalizeTypes
 import NormalizeTools
+import VHDLTypes
 import CoreTools
 import Pretty
 
@@ -97,14 +101,21 @@ letrectop = everywhere ("letrec", letrec)
 -- let simplification
 --------------------------------
 letsimpl, letsimpltop :: Transform
--- Don't simplifiy lets that are already simple
-letsimpl expr@(Let _ (Var _)) = return expr
 -- Put the "in ..." value of a let in its own binding, but not when the
 -- expression is applicable (to prevent loops with inlinefun).
-letsimpl (Let (Rec binds) expr) | not $ is_applicable expr = do
-  id <- mkInternalVar "foo" (CoreUtils.exprType expr)
-  let bind = (id, expr)
-  change $ Let (Rec (bind:binds)) (Var id)
+letsimpl expr@(Let (Rec binds) res) | not $ is_applicable expr = do
+  local_var <- Trans.lift $ is_local_var res
+  if not local_var
+    then do
+      -- If the result is not a local var already (to prevent loops with
+      -- ourselves), extract it.
+      id <- mkInternalVar "foo" (CoreUtils.exprType res)
+      let bind = (id, res)
+      change $ Let (Rec (bind:binds)) (Var id)
+    else
+      -- If the result is already a local var, don't extract it.
+      return expr
+
 -- Leave all other expressions unchanged
 letsimpl expr = return expr
 -- Perform this transform everywhere
@@ -140,7 +151,7 @@ letflattop = everywhere ("letflat", letflat)
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
 letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) | not $ Id.isDataConWorkId v -> return True; otherwise -> return False))
+letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
 --------------------------------
 -- Function inlining
@@ -194,7 +205,7 @@ casewild expr@(Case scrut b ty alts) = do
   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
   where
   -- Generate a single wild binder, since they are all the same
-  wild = Id.mkWildId
+  wild = MkCore.mkWildBinder
   -- Wilden the binders of one alt, producing a list of bindings as a
   -- sideeffect.
   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
@@ -208,15 +219,17 @@ casewild expr@(Case scrut b ty alts) = do
     return (bindings, newalt)
     where
       -- Make all binders wild
-      wildbndrs = map (\bndr -> Id.mkWildId (Id.idType bndr)) bndrs
+      wildbndrs = map (\bndr -> MkCore.mkWildBinder (Id.idType bndr)) bndrs
+      -- A set of all the binders that are used by the expression
+      free_vars = CoreFVs.exprSomeFreeVars (`elem` bndrs) expr
       -- Creates a case statement to retrieve the ith element from the scrutinee
       -- and binds that to b.
       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
       mkextracts b i =
-        -- TODO: Use free variables instead of is_wild. is_wild is a hack.
-        if is_wild b || Type.isFunTy (Id.idType b) 
-          -- Don't create extra bindings for binders that are already wild, or
-          -- for binders that bind function types (to prevent loops with
+        if not (VarSet.elemVarSet b free_vars) || Type.isFunTy (Id.idType b) 
+          -- Don't create extra bindings for binders that are already wild
+          -- (e.g. not in the free variables of expr, so unused), or for
+          -- binders that bind function types (to prevent loops with
           -- inlinefun).
           then return Nothing
           else do
@@ -289,14 +302,13 @@ caseremovetop = everywhere ("caseremove", caseremove)
 --------------------------------
 -- Make sure that all arguments of a representable type are simple variables.
 appsimpl, appsimpltop :: Transform
--- Don't simplify arguments that are already simple.
-appsimpl expr@(App f (Var v)) = return expr
 -- Simplify all representable arguments. Do this by introducing a new Let
 -- that binds the argument and passing the new binder in the application.
 appsimpl expr@(App f arg) = do
   -- Check runtime representability
   repr <- isRepr arg
-  if repr
+  local_var <- Trans.lift $ is_local_var arg
+  if repr && not local_var
     then do -- Extract representable arguments
       id <- mkInternalVar "arg" (CoreUtils.exprType arg)
       change $ Let (Rec [(id, arg)]) (App f (Var id))
@@ -451,14 +463,15 @@ funextracttop = everywhere ("funextract", funextract)
 transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinenonreptop, appsimpltop]
 
 -- Turns the given bind into VHDL
-normalizeModule :: 
-  UniqSupply.UniqSupply -- ^ A UniqSupply we can use
+normalizeModule ::
+  HscTypes.HscEnv
+  -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
-  -> [(CoreBndr, CoreExpr)] -- ^ The resulting VHDL
+  -> ([(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
 
-normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession uniqsupply $ do
+normalizeModule env uniqsupply bindings generate_for statefuls = runTransformSession env uniqsupply $ do
   -- Put all the bindings in this module in the tsBindings map
   putA tsBindings (Map.fromList bindings)
   -- (Recursively) normalize each of the requested bindings
@@ -467,14 +480,15 @@ normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession
   bindings_map <- getA tsBindings
   let bindings = Map.assocs bindings_map
   normalized_bindings <- getA tsNormalized
+  typestate <- getA tsType
   -- But return only the normalized bindings
-  return $ filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings
+  return $ (filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings, typestate)
 
 normalizeBind :: CoreBndr -> TransformSession ()
 normalizeBind bndr =
   -- Don't normalize global variables, these should be either builtin
   -- functions or data constructors.
-  Monad.when (Var.isLocalIdVar bndr) $ do
+  Monad.when (Var.isLocalId bndr) $ do
     -- Skip binders that have a polymorphic type, since it's impossible to
     -- create polymorphic hardware.
     if is_poly (Var bndr)
@@ -508,7 +522,8 @@ normalizeBind bndr =
                 -- Find all vars used with a function type. All of these should be global
                 -- binders (i.e., functions used), since any local binders with a function
                 -- type should have been inlined already.
-                let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> (Type.isFunTy . snd . Type.splitForAllTys . Id.idType) v) expr'
+                bndrs <- getGlobalBinders
+                let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
                 let used_funcs = VarSet.varSetElems used_funcs_set
                 -- Process each of the used functions recursively
                 mapM normalizeBind used_funcs