Merge branch 'cλash' of http://git.stderr.nl/matthijs/projects/master-project
[matthijs/master-project/cλash.git] / NormalizeTools.hs
index 5ea3a7db8ab852fce0a7dd8529573e718ff2e7eb..1290fd85bd8b19d7aa93a90f59b81c779061c796 100644 (file)
@@ -7,6 +7,7 @@ module NormalizeTools where
 import Debug.Trace
 import qualified List
 import qualified Data.Monoid as Monoid
 import Debug.Trace
 import qualified List
 import qualified Data.Monoid as Monoid
+import qualified Data.Either as Either
 import qualified Control.Arrow as Arrow
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.State as State
 import qualified Control.Arrow as Arrow
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.State as State
@@ -29,11 +30,13 @@ import qualified IdInfo
 import qualified CoreUtils
 import qualified CoreSubst
 import qualified VarSet
 import qualified CoreUtils
 import qualified CoreSubst
 import qualified VarSet
+import qualified HscTypes
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import NormalizeTypes
 import Pretty
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import NormalizeTypes
 import Pretty
+import VHDLTypes
 import qualified VHDLTools
 
 -- Create a new internal var with the given name and type. A Unique is
 import qualified VHDLTools
 
 -- Create a new internal var with the given name and type. A Unique is
@@ -103,15 +106,15 @@ applyboth first (name, second) expr  = do
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second expr'
   if Monoid.getAny $
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second expr'
   if Monoid.getAny $
-  --      trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
+--        trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
         changed 
     then 
 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
         changed 
     then 
 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
- --     trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
+--      trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
       applyboth first (name, second) $
         expr'' 
     else 
       applyboth first (name, second) $
         expr'' 
     else 
-    --  trace ("No changes") $
+--      trace ("No changes") $
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
@@ -185,14 +188,23 @@ dotransforms transs expr = do
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
-inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
-inlinebind condition (Let (Rec binds) expr) | not $ null replace =
-    change newexpr
+inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
+inlinebind condition expr@(Let (Rec binds) res) = do
+    -- Find all bindings that adhere to the condition
+    res_eithers <- mapM docond binds
+    case Either.partitionEithers res_eithers of
+      -- No replaces? No change
+      ([], _) -> return expr
+      (replace, others) -> do
+        -- Substitute the to be replaced binders with their expression
+        let newexpr = substitute replace (Let (Rec others) res)
+        change newexpr
   where 
   where 
-    -- Find all simple bindings
-    (replace, others) = List.partition condition binds
-    -- Substitute the to be replaced binders with their expression
-    newexpr = substitute replace (Let (Rec others) expr)
+    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
+    docond b = do
+      res <- condition b
+      return $ case res of True -> Left b; False -> Right b
+
 -- Leave all other expressions unchanged
 inlinebind _ expr = return expr
 
 -- Leave all other expressions unchanged
 inlinebind _ expr = return expr
 
@@ -236,10 +248,19 @@ substitute ((b, e):subss) expr = substitute subss' expr'
 
 -- Run a given TransformSession. Used mostly to setup the right calls and
 -- an initial state.
 
 -- Run a given TransformSession. Used mostly to setup the right calls and
 -- an initial state.
-runTransformSession :: UniqSupply.UniqSupply -> TransformSession a -> a
-runTransformSession uniqSupply session = State.evalState session (emptyTransformState uniqSupply)
+runTransformSession :: HscTypes.HscEnv -> UniqSupply.UniqSupply -> TransformSession a -> a
+runTransformSession env uniqSupply session = State.evalState session emptyTransformState
+  where
+    emptyTypeState = TypeState Map.empty [] Map.empty Map.empty env
+    emptyTransformState = TransformState uniqSupply Map.empty VarSet.emptyVarSet emptyTypeState
 
 -- Is the given expression representable at runtime, based on the type?
 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
 isRepr (Type ty) = return False
 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
 
 -- Is the given expression representable at runtime, based on the type?
 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
 isRepr (Type ty) = return False
 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
+
+is_local_var :: CoreSyn.CoreExpr -> TransformSession Bool
+is_local_var (CoreSyn.Var v) = do
+  bndrs <- getGlobalBinders
+  return $ not $ v `elem` bndrs
+is_local_var _ = return False