Make beta reduction of Case expressions work for type arguments.
[matthijs/master-project/cλash.git] / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module NormalizeTools where
6 -- Standard modules
7 import Debug.Trace
8 import qualified List
9 import qualified Data.Monoid as Monoid
10 import qualified Control.Monad as Monad
11 import qualified Control.Monad.Trans.State as State
12 import qualified Control.Monad.Trans.Writer as Writer
13 import qualified "transformers" Control.Monad.Trans as Trans
14 import qualified Data.Map as Map
15 import Data.Accessor
16
17 -- GHC API
18 import CoreSyn
19 import qualified UniqSupply
20 import qualified Unique
21 import qualified OccName
22 import qualified Name
23 import qualified Var
24 import qualified SrcLoc
25 import qualified Type
26 import qualified IdInfo
27 import qualified CoreUtils
28 import qualified CoreSubst
29 import qualified VarSet
30 import Outputable ( showSDoc, ppr, nest )
31
32 -- Local imports
33 import NormalizeTypes
34
35 -- Create a new internal var with the given name and type. A Unique is
36 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
37 -- since the Unique is also stored in the name, but this ensures variable
38 -- names are unique in the output).
39 mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
40 mkInternalVar str ty = do
41   uniq <- mkUnique
42   let occname = OccName.mkVarOcc (str ++ show uniq)
43   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
44   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
45
46 -- Create a new type variable with the given name and kind. A Unique is
47 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
48 -- since the Unique is also stored in the name, but this ensures variable
49 -- names are unique in the output).
50 mkTypeVar :: String -> Type.Kind -> TransformMonad Var.Var
51 mkTypeVar str kind = do
52   uniq <- mkUnique
53   let occname = OccName.mkVarOcc (str ++ show uniq)
54   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
55   return $ Var.mkTyVar name kind
56
57 -- Creates a binder for the given expression with the given name. This
58 -- works for both value and type level expressions, so it can return a Var or
59 -- TyVar (which is just an alias for Var).
60 mkBinderFor :: CoreExpr -> String -> TransformMonad Var.Var
61 mkBinderFor (Type ty) string = mkTypeVar string (Type.typeKind ty)
62 mkBinderFor expr string = mkInternalVar string (CoreUtils.exprType expr)
63
64 -- Creates a reference to the given variable. This works for both a normal
65 -- variable as well as a type variable
66 mkReferenceTo :: Var.Var -> CoreExpr
67 mkReferenceTo var | Var.isTyVar var = (Type $ Type.mkTyVarTy var)
68                   | otherwise       = (Var var)
69
70 cloneVar :: Var.Var -> TransformMonad Var.Var
71 cloneVar v = do
72   uniq <- mkUnique
73   -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
74   -- contains, but vannillaIdInfo is always correct, since it means "no info").
75   return $ Var.lazySetVarIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
76
77 -- Apply the given transformation to all expressions in the given expression,
78 -- including the expression itself.
79 everywhere :: (String, Transform) -> Transform
80 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
81
82 -- Apply the first transformation, followed by the second transformation, and
83 -- keep applying both for as long as expression still changes.
84 applyboth :: Transform -> (String, Transform) -> Transform
85 applyboth first (name, second) expr  = do
86   -- Apply the first
87   expr' <- first expr
88   -- Apply the second
89   (expr'', changed) <- Writer.listen $ second expr'
90   if Monoid.getAny $
91   --      trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
92         changed 
93     then 
94 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
95  --     trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
96       applyboth first (name, second) $
97         expr'' 
98     else 
99     --  trace ("No changes") $
100       return expr''
101
102 -- Apply the given transformation to all direct subexpressions (only), not the
103 -- expression itself.
104 subeverywhere :: Transform -> Transform
105 subeverywhere trans (App a b) = do
106   a' <- trans a
107   b' <- trans b
108   return $ App a' b'
109
110 subeverywhere trans (Let (Rec binds) expr) = do
111   expr' <- trans expr
112   binds' <- mapM transbind binds
113   return $ Let (Rec binds') expr'
114   where
115     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
116     transbind (b, e) = do
117       e' <- trans e
118       return (b, e')
119
120 subeverywhere trans (Lam x expr) = do
121   expr' <- trans expr
122   return $ Lam x expr'
123
124 subeverywhere trans (Case scrut b t alts) = do
125   scrut' <- trans scrut
126   alts' <- mapM transalt alts
127   return $ Case scrut' b t alts'
128   where
129     transalt :: CoreAlt -> TransformMonad CoreAlt
130     transalt (con, binders, expr) = do
131       expr' <- trans expr
132       return (con, binders, expr')
133       
134
135 subeverywhere trans expr = return expr
136
137 -- Apply the given transformation to all expressions, except for every first
138 -- argument of an application.
139 notapplied :: (String, Transform) -> Transform
140 notapplied trans = applyboth (subnotapplied trans) trans
141
142 -- Apply the given transformation to all (direct and indirect) subexpressions
143 -- (but not the expression itself), except for the first argument of an
144 -- applicfirst argument of an application
145 subnotapplied :: (String, Transform) -> Transform
146 subnotapplied trans (App a b) = do
147   a' <- subnotapplied trans a
148   b' <- notapplied trans b
149   return $ App a' b'
150
151 -- Let subeverywhere handle all other expressions
152 subnotapplied trans expr = subeverywhere (notapplied trans) expr
153
154 -- Runs each of the transforms repeatedly inside the State monad.
155 dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
156 dotransforms transs expr = do
157   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
158   if Monoid.getAny changed then dotransforms transs expr' else return expr'
159
160 -- Inline all let bindings that satisfy the given condition
161 inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
162 inlinebind condition (Let (Rec binds) expr) | not $ null replace =
163     change newexpr
164   where 
165     -- Find all simple bindings
166     (replace, others) = List.partition condition binds
167     -- Substitute the to be replaced binders with their expression
168     newexpr = substitute replace (Let (Rec others) expr)
169 -- Leave all other expressions unchanged
170 inlinebind _ expr = return expr
171
172 -- Sets the changed flag in the TransformMonad, to signify that some
173 -- transform has changed the result
174 setChanged :: TransformMonad ()
175 setChanged = Writer.tell (Monoid.Any True)
176
177 -- Sets the changed flag and returns the given value.
178 change :: a -> TransformMonad a
179 change val = do
180   setChanged
181   return val
182
183 -- Create a new Unique
184 mkUnique :: TransformMonad Unique.Unique
185 mkUnique = Trans.lift $ do
186     us <- getA tsUniqSupply 
187     let (us', us'') = UniqSupply.splitUniqSupply us
188     putA tsUniqSupply us'
189     return $ UniqSupply.uniqFromSupply us''
190
191 -- Replace each of the binders given with the coresponding expressions in the
192 -- given expression.
193 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
194 substitute replace expr = CoreSubst.substExpr subs expr
195     where subs = foldl (\s (b, e) -> CoreSubst.extendSubst s b e) CoreSubst.emptySubst replace
196
197 -- Run a given TransformSession. Used mostly to setup the right calls and
198 -- an initial state.
199 runTransformSession :: UniqSupply.UniqSupply -> TransformSession a -> a
200 runTransformSession uniqSupply session = State.evalState session initState
201                        where initState = TransformState uniqSupply Map.empty VarSet.emptyVarSet