Make subeverywhere complain for unknown expressions.
[matthijs/master-project/cλash.git] / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module NormalizeTools where
6 -- Standard modules
7 import Debug.Trace
8 import qualified List
9 import qualified Data.Monoid as Monoid
10 import qualified Control.Arrow as Arrow
11 import qualified Control.Monad as Monad
12 import qualified Control.Monad.Trans.State as State
13 import qualified Control.Monad.Trans.Writer as Writer
14 import qualified "transformers" Control.Monad.Trans as Trans
15 import qualified Data.Map as Map
16 import Data.Accessor
17
18 -- GHC API
19 import CoreSyn
20 import qualified UniqSupply
21 import qualified Unique
22 import qualified OccName
23 import qualified Name
24 import qualified Var
25 import qualified SrcLoc
26 import qualified Type
27 import qualified IdInfo
28 import qualified CoreUtils
29 import qualified CoreSubst
30 import qualified VarSet
31 import Outputable ( showSDoc, ppr, nest )
32
33 -- Local imports
34 import NormalizeTypes
35
36 -- Create a new internal var with the given name and type. A Unique is
37 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
38 -- since the Unique is also stored in the name, but this ensures variable
39 -- names are unique in the output).
40 mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
41 mkInternalVar str ty = do
42   uniq <- mkUnique
43   let occname = OccName.mkVarOcc (str ++ show uniq)
44   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
45   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
46
47 -- Create a new type variable with the given name and kind. A Unique is
48 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
49 -- since the Unique is also stored in the name, but this ensures variable
50 -- names are unique in the output).
51 mkTypeVar :: String -> Type.Kind -> TransformMonad Var.Var
52 mkTypeVar str kind = do
53   uniq <- mkUnique
54   let occname = OccName.mkVarOcc (str ++ show uniq)
55   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
56   return $ Var.mkTyVar name kind
57
58 -- Creates a binder for the given expression with the given name. This
59 -- works for both value and type level expressions, so it can return a Var or
60 -- TyVar (which is just an alias for Var).
61 mkBinderFor :: CoreExpr -> String -> TransformMonad Var.Var
62 mkBinderFor (Type ty) string = mkTypeVar string (Type.typeKind ty)
63 mkBinderFor expr string = mkInternalVar string (CoreUtils.exprType expr)
64
65 -- Creates a reference to the given variable. This works for both a normal
66 -- variable as well as a type variable
67 mkReferenceTo :: Var.Var -> CoreExpr
68 mkReferenceTo var | Var.isTyVar var = (Type $ Type.mkTyVarTy var)
69                   | otherwise       = (Var var)
70
71 cloneVar :: Var.Var -> TransformMonad Var.Var
72 cloneVar v = do
73   uniq <- mkUnique
74   -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
75   -- contains, but vannillaIdInfo is always correct, since it means "no info").
76   return $ Var.lazySetVarIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
77
78 -- Creates a new function with the same name as the given binder (but with a
79 -- new unique) and with the given function body. Returns the new binder for
80 -- this function.
81 mkFunction :: CoreBndr -> CoreExpr -> TransformMonad CoreBndr
82 mkFunction bndr body = do
83   let ty = CoreUtils.exprType body
84   id <- cloneVar bndr
85   let newid = Var.setVarType id ty
86   Trans.lift $ addGlobalBind newid body
87   return newid
88
89 -- Apply the given transformation to all expressions in the given expression,
90 -- including the expression itself.
91 everywhere :: (String, Transform) -> Transform
92 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
93
94 -- Apply the first transformation, followed by the second transformation, and
95 -- keep applying both for as long as expression still changes.
96 applyboth :: Transform -> (String, Transform) -> Transform
97 applyboth first (name, second) expr  = do
98   -- Apply the first
99   expr' <- first expr
100   -- Apply the second
101   (expr'', changed) <- Writer.listen $ second expr'
102   if Monoid.getAny $
103   --      trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
104         changed 
105     then 
106 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
107  --     trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
108       applyboth first (name, second) $
109         expr'' 
110     else 
111     --  trace ("No changes") $
112       return expr''
113
114 -- Apply the given transformation to all direct subexpressions (only), not the
115 -- expression itself.
116 subeverywhere :: Transform -> Transform
117 subeverywhere trans (App a b) = do
118   a' <- trans a
119   b' <- trans b
120   return $ App a' b'
121
122 subeverywhere trans (Let (Rec binds) expr) = do
123   expr' <- trans expr
124   binds' <- mapM transbind binds
125   return $ Let (Rec binds') expr'
126   where
127     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
128     transbind (b, e) = do
129       e' <- trans e
130       return (b, e')
131
132 subeverywhere trans (Lam x expr) = do
133   expr' <- trans expr
134   return $ Lam x expr'
135
136 subeverywhere trans (Case scrut b t alts) = do
137   scrut' <- trans scrut
138   alts' <- mapM transalt alts
139   return $ Case scrut' b t alts'
140   where
141     transalt :: CoreAlt -> TransformMonad CoreAlt
142     transalt (con, binders, expr) = do
143       expr' <- trans expr
144       return (con, binders, expr')
145
146 subeverywhere trans (Var x) = return $ Var x
147 subeverywhere trans (Lit x) = return $ Lit x
148 subeverywhere trans (Type x) = return $ Type x
149
150 subeverywhere trans expr = error $ "NormalizeTools.subeverywhere Unsupported expression: " ++ show expr
151
152 -- Apply the given transformation to all expressions, except for direct
153 -- arguments of an application
154 notappargs :: (String, Transform) -> Transform
155 notappargs trans = applyboth (subnotappargs trans) trans
156
157 -- Apply the given transformation to all (direct and indirect) subexpressions
158 -- (but not the expression itself), except for direct arguments of an
159 -- application
160 subnotappargs :: (String, Transform) -> Transform
161 subnotappargs trans (App a b) = do
162   a' <- subnotappargs trans a
163   b' <- subnotappargs trans b
164   return $ App a' b'
165
166 -- Let subeverywhere handle all other expressions
167 subnotappargs trans expr = subeverywhere (notappargs trans) expr
168
169 -- Runs each of the transforms repeatedly inside the State monad.
170 dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
171 dotransforms transs expr = do
172   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
173   if Monoid.getAny changed then dotransforms transs expr' else return expr'
174
175 -- Inline all let bindings that satisfy the given condition
176 inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
177 inlinebind condition (Let (Rec binds) expr) | not $ null replace =
178     change newexpr
179   where 
180     -- Find all simple bindings
181     (replace, others) = List.partition condition binds
182     -- Substitute the to be replaced binders with their expression
183     newexpr = substitute replace (Let (Rec others) expr)
184 -- Leave all other expressions unchanged
185 inlinebind _ expr = return expr
186
187 -- Sets the changed flag in the TransformMonad, to signify that some
188 -- transform has changed the result
189 setChanged :: TransformMonad ()
190 setChanged = Writer.tell (Monoid.Any True)
191
192 -- Sets the changed flag and returns the given value.
193 change :: a -> TransformMonad a
194 change val = do
195   setChanged
196   return val
197
198 -- Create a new Unique
199 mkUnique :: TransformMonad Unique.Unique
200 mkUnique = Trans.lift $ do
201     us <- getA tsUniqSupply 
202     let (us', us'') = UniqSupply.splitUniqSupply us
203     putA tsUniqSupply us'
204     return $ UniqSupply.uniqFromSupply us''
205
206 -- Replace each of the binders given with the coresponding expressions in the
207 -- given expression.
208 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
209 substitute [] expr = expr
210 -- Apply one substitution on the expression, but also on any remaining
211 -- substitutions. This seems to be the only way to handle substitutions like
212 -- [(b, c), (a, b)]. This means we reuse a substitution, which is not allowed
213 -- according to CoreSubst documentation (but it doesn't seem to be a problem).
214 -- TODO: Find out how this works, exactly.
215 substitute ((b, e):subss) expr = substitute subss' expr'
216   where 
217     -- Create the Subst
218     subs = (CoreSubst.extendSubst CoreSubst.emptySubst b e)
219     -- Apply this substitution to the main expression
220     expr' = CoreSubst.substExpr subs expr
221     -- Apply this substitution on all the expressions in the remaining
222     -- substitutions
223     subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
224
225 -- Run a given TransformSession. Used mostly to setup the right calls and
226 -- an initial state.
227 runTransformSession :: UniqSupply.UniqSupply -> TransformSession a -> a
228 runTransformSession uniqSupply session = State.evalState session initState
229                        where initState = TransformState uniqSupply Map.empty VarSet.emptyVarSet