Make substitute work for type variables as well.
[matthijs/master-project/cλash.git] / NormalizeTools.hs
index 6699101d33bc18a3a90a8870ce18d095cd72b129..923b54567767e68d7ccc68bfda2c0db9c70114f2 100644 (file)
@@ -11,6 +11,7 @@ import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.State as State
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad.Trans.State as State
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
+import qualified Data.Map as Map
 import Data.Accessor
 
 -- GHC API
 import Data.Accessor
 
 -- GHC API
@@ -25,6 +26,7 @@ import qualified Type
 import qualified IdInfo
 import qualified CoreUtils
 import qualified CoreSubst
 import qualified IdInfo
 import qualified CoreUtils
 import qualified CoreSubst
+import qualified VarSet
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
@@ -41,6 +43,13 @@ mkInternalVar str ty = do
   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
 
   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
 
+cloneVar :: Var.Var -> TransformMonad Var.Var
+cloneVar v = do
+  uniq <- mkUnique
+  -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
+  -- contains, but vannillaIdInfo is always correct, since it means "no info").
+  return $ Var.lazySetVarIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
+
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
 everywhere :: (String, Transform) -> Transform
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
 everywhere :: (String, Transform) -> Transform
@@ -54,11 +63,16 @@ applyboth first (name, second) expr  = do
   expr' <- first expr
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second expr'
   expr' <- first expr
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second expr'
-  if Monoid.getAny changed 
+  if Monoid.getAny $
+  --      trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
+        changed 
     then 
     then 
-      trace ("Transform " ++ name ++ " changed from:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n" ++ "\nTo:\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
-      applyboth first (name, second) expr'' 
+--      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
+ --     trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
+      applyboth first (name, second) $
+        expr'' 
     else 
     else 
+    --  trace ("No changes") $
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
@@ -113,16 +127,11 @@ subnotapplied trans (App a b) = do
 -- Let subeverywhere handle all other expressions
 subnotapplied trans expr = subeverywhere (notapplied trans) expr
 
 -- Let subeverywhere handle all other expressions
 subnotapplied trans expr = subeverywhere (notapplied trans) expr
 
--- Run the given transforms over the given expression
-dotransforms :: [Transform] -> UniqSupply.UniqSupply -> CoreExpr -> CoreExpr
-dotransforms transs uniqSupply = (flip State.evalState initState) . (dotransforms' transs)
-                       where initState = TransformState uniqSupply
-
 -- Runs each of the transforms repeatedly inside the State monad.
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms' :: [Transform] -> CoreExpr -> State.State TransformState CoreExpr
-dotransforms' transs expr = do
+dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
+dotransforms transs expr = do
   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
-  if Monoid.getAny changed then dotransforms' transs expr' else return expr'
+  if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
@@ -159,4 +168,10 @@ mkUnique = Trans.lift $ do
 -- given expression.
 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
 substitute replace expr = CoreSubst.substExpr subs expr
 -- given expression.
 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
 substitute replace expr = CoreSubst.substExpr subs expr
-    where subs = foldl (\s (b, e) -> CoreSubst.extendIdSubst s b e) CoreSubst.emptySubst replace
+    where subs = foldl (\s (b, e) -> CoreSubst.extendSubst s b e) CoreSubst.emptySubst replace
+
+-- Run a given TransformSession. Used mostly to setup the right calls and
+-- an initial state.
+runTransformSession :: UniqSupply.UniqSupply -> TransformSession a -> a
+runTransformSession uniqSupply session = State.evalState session initState
+                       where initState = TransformState uniqSupply Map.empty VarSet.emptyVarSet