Add a substitute helper function.
[matthijs/master-project/cλash.git] / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module NormalizeTools where
6 -- Standard modules
7 import Debug.Trace
8 import qualified Data.Monoid as Monoid
9 import qualified Control.Monad as Monad
10 import qualified Control.Monad.Trans.State as State
11 import qualified Control.Monad.Trans.Writer as Writer
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import Data.Accessor
14
15 -- GHC API
16 import CoreSyn
17 import qualified UniqSupply
18 import qualified Unique
19 import qualified OccName
20 import qualified Name
21 import qualified Var
22 import qualified SrcLoc
23 import qualified Type
24 import qualified IdInfo
25 import qualified CoreUtils
26 import qualified CoreSubst
27 import Outputable ( showSDoc, ppr, nest )
28
29 -- Local imports
30 import NormalizeTypes
31
32 -- Create a new internal var with the given name and type. A Unique is
33 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
34 -- since the Unique is also stored in the name, but this ensures variable
35 -- names are unique in the output).
36 mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
37 mkInternalVar str ty = do
38   uniq <- mkUnique
39   let occname = OccName.mkVarOcc (str ++ show uniq)
40   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
41   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
42
43 -- Apply the given transformation to all expressions in the given expression,
44 -- including the expression itself.
45 everywhere :: (String, Transform) -> Transform
46 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
47
48 -- Apply the first transformation, followed by the second transformation, and
49 -- keep applying both for as long as expression still changes.
50 applyboth :: Transform -> (String, Transform) -> Transform
51 applyboth first (name, second) expr  = do
52   -- Apply the first
53   expr' <- first expr
54   -- Apply the second
55   (expr'', changed) <- Writer.listen $ second expr'
56   if Monoid.getAny changed 
57     then 
58       trace ("Transform " ++ name ++ " changed from:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n" ++ "\nTo:\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
59       applyboth first (name, second) expr'' 
60     else 
61       return expr''
62
63 -- Apply the given transformation to all direct subexpressions (only), not the
64 -- expression itself.
65 subeverywhere :: Transform -> Transform
66 subeverywhere trans (App a b) = do
67   a' <- trans a
68   b' <- trans b
69   return $ App a' b'
70
71 subeverywhere trans (Let (Rec binds) expr) = do
72   expr' <- trans expr
73   binds' <- mapM transbind binds
74   return $ Let (Rec binds') expr'
75   where
76     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
77     transbind (b, e) = do
78       e' <- trans e
79       return (b, e')
80
81 subeverywhere trans (Lam x expr) = do
82   expr' <- trans expr
83   return $ Lam x expr'
84
85 subeverywhere trans (Case scrut b t alts) = do
86   scrut' <- trans scrut
87   alts' <- mapM transalt alts
88   return $ Case scrut' b t alts'
89   where
90     transalt :: CoreAlt -> TransformMonad CoreAlt
91     transalt (con, binders, expr) = do
92       expr' <- trans expr
93       return (con, binders, expr')
94       
95
96 subeverywhere trans expr = return expr
97
98 -- Apply the given transformation to all expressions, except for every first
99 -- argument of an application.
100 notapplied :: (String, Transform) -> Transform
101 notapplied trans = applyboth (subnotapplied trans) trans
102
103 -- Apply the given transformation to all (direct and indirect) subexpressions
104 -- (but not the expression itself), except for the first argument of an
105 -- applicfirst argument of an application
106 subnotapplied :: (String, Transform) -> Transform
107 subnotapplied trans (App a b) = do
108   a' <- subnotapplied trans a
109   b' <- notapplied trans b
110   return $ App a' b'
111
112 -- Let subeverywhere handle all other expressions
113 subnotapplied trans expr = subeverywhere (notapplied trans) expr
114
115 -- Run the given transforms over the given expression
116 dotransforms :: [Transform] -> UniqSupply.UniqSupply -> CoreExpr -> CoreExpr
117 dotransforms transs uniqSupply = (flip State.evalState initState) . (dotransforms' transs)
118                        where initState = TransformState uniqSupply
119
120 -- Runs each of the transforms repeatedly inside the State monad.
121 dotransforms' :: [Transform] -> CoreExpr -> State.State TransformState CoreExpr
122 dotransforms' transs expr = do
123   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
124   if Monoid.getAny changed then dotransforms' transs expr' else return expr'
125
126 -- Sets the changed flag in the TransformMonad, to signify that some
127 -- transform has changed the result
128 setChanged :: TransformMonad ()
129 setChanged = Writer.tell (Monoid.Any True)
130
131 -- Sets the changed flag and returns the given value.
132 change :: a -> TransformMonad a
133 change val = do
134   setChanged
135   return val
136
137 -- Create a new Unique
138 mkUnique :: TransformMonad Unique.Unique
139 mkUnique = Trans.lift $ do
140     us <- getA tsUniqSupply 
141     let (us', us'') = UniqSupply.splitUniqSupply us
142     putA tsUniqSupply us'
143     return $ UniqSupply.uniqFromSupply us''
144
145 -- Replace each of the binders given with the coresponding expressions in the
146 -- given expression.
147 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
148 substitute replace expr = CoreSubst.substExpr subs expr
149     where subs = foldl (\s (b, e) -> CoreSubst.extendIdSubst s b e) CoreSubst.emptySubst replace