Add predicates for testing representability of types.
[matthijs/master-project/cλash.git] / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module NormalizeTools where
6 -- Standard modules
7 import Debug.Trace
8 import qualified List
9 import qualified Data.Monoid as Monoid
10 import qualified Control.Arrow as Arrow
11 import qualified Control.Monad as Monad
12 import qualified Control.Monad.Trans.State as State
13 import qualified Control.Monad.Trans.Writer as Writer
14 import qualified "transformers" Control.Monad.Trans as Trans
15 import qualified Data.Map as Map
16 import Data.Accessor
17 import Data.Accessor.MonadState as MonadState
18
19 -- GHC API
20 import CoreSyn
21 import qualified UniqSupply
22 import qualified Unique
23 import qualified OccName
24 import qualified Name
25 import qualified Var
26 import qualified SrcLoc
27 import qualified Type
28 import qualified IdInfo
29 import qualified CoreUtils
30 import qualified CoreSubst
31 import qualified VarSet
32 import Outputable ( showSDoc, ppr, nest )
33
34 -- Local imports
35 import NormalizeTypes
36 import Pretty
37 import qualified VHDLTools
38
39 -- Create a new internal var with the given name and type. A Unique is
40 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
41 -- since the Unique is also stored in the name, but this ensures variable
42 -- names are unique in the output).
43 mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
44 mkInternalVar str ty = do
45   uniq <- mkUnique
46   let occname = OccName.mkVarOcc (str ++ show uniq)
47   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
48   return $ Var.mkLocalIdVar name ty IdInfo.vanillaIdInfo
49
50 -- Create a new type variable with the given name and kind. A Unique is
51 -- appended to the given name, to ensure uniqueness (not strictly neccesary,
52 -- since the Unique is also stored in the name, but this ensures variable
53 -- names are unique in the output).
54 mkTypeVar :: String -> Type.Kind -> TransformMonad Var.Var
55 mkTypeVar str kind = do
56   uniq <- mkUnique
57   let occname = OccName.mkVarOcc (str ++ show uniq)
58   let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
59   return $ Var.mkTyVar name kind
60
61 -- Creates a binder for the given expression with the given name. This
62 -- works for both value and type level expressions, so it can return a Var or
63 -- TyVar (which is just an alias for Var).
64 mkBinderFor :: CoreExpr -> String -> TransformMonad Var.Var
65 mkBinderFor (Type ty) string = mkTypeVar string (Type.typeKind ty)
66 mkBinderFor expr string = mkInternalVar string (CoreUtils.exprType expr)
67
68 -- Creates a reference to the given variable. This works for both a normal
69 -- variable as well as a type variable
70 mkReferenceTo :: Var.Var -> CoreExpr
71 mkReferenceTo var | Var.isTyVar var = (Type $ Type.mkTyVarTy var)
72                   | otherwise       = (Var var)
73
74 cloneVar :: Var.Var -> TransformMonad Var.Var
75 cloneVar v = do
76   uniq <- mkUnique
77   -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
78   -- contains, but vannillaIdInfo is always correct, since it means "no info").
79   return $ Var.lazySetVarIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
80
81 -- Creates a new function with the same name as the given binder (but with a
82 -- new unique) and with the given function body. Returns the new binder for
83 -- this function.
84 mkFunction :: CoreBndr -> CoreExpr -> TransformMonad CoreBndr
85 mkFunction bndr body = do
86   let ty = CoreUtils.exprType body
87   id <- cloneVar bndr
88   let newid = Var.setVarType id ty
89   Trans.lift $ addGlobalBind newid body
90   return newid
91
92 -- Apply the given transformation to all expressions in the given expression,
93 -- including the expression itself.
94 everywhere :: (String, Transform) -> Transform
95 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
96
97 -- Apply the first transformation, followed by the second transformation, and
98 -- keep applying both for as long as expression still changes.
99 applyboth :: Transform -> (String, Transform) -> Transform
100 applyboth first (name, second) expr  = do
101   -- Apply the first
102   expr' <- first expr
103   -- Apply the second
104   (expr'', changed) <- Writer.listen $ second expr'
105   if Monoid.getAny $
106   --      trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
107         changed 
108     then 
109 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
110  --     trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
111       applyboth first (name, second) $
112         expr'' 
113     else 
114     --  trace ("No changes") $
115       return expr''
116
117 -- Apply the given transformation to all direct subexpressions (only), not the
118 -- expression itself.
119 subeverywhere :: Transform -> Transform
120 subeverywhere trans (App a b) = do
121   a' <- trans a
122   b' <- trans b
123   return $ App a' b'
124
125 subeverywhere trans (Let (NonRec b bexpr) expr) = do
126   bexpr' <- trans bexpr
127   expr' <- trans expr
128   return $ Let (NonRec b bexpr') expr'
129
130 subeverywhere trans (Let (Rec binds) expr) = do
131   expr' <- trans expr
132   binds' <- mapM transbind binds
133   return $ Let (Rec binds') expr'
134   where
135     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
136     transbind (b, e) = do
137       e' <- trans e
138       return (b, e')
139
140 subeverywhere trans (Lam x expr) = do
141   expr' <- trans expr
142   return $ Lam x expr'
143
144 subeverywhere trans (Case scrut b t alts) = do
145   scrut' <- trans scrut
146   alts' <- mapM transalt alts
147   return $ Case scrut' b t alts'
148   where
149     transalt :: CoreAlt -> TransformMonad CoreAlt
150     transalt (con, binders, expr) = do
151       expr' <- trans expr
152       return (con, binders, expr')
153
154 subeverywhere trans (Var x) = return $ Var x
155 subeverywhere trans (Lit x) = return $ Lit x
156 subeverywhere trans (Type x) = return $ Type x
157
158 subeverywhere trans (Cast expr ty) = do
159   expr' <- trans expr
160   return $ Cast expr' ty
161
162 subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
163
164 -- Apply the given transformation to all expressions, except for direct
165 -- arguments of an application
166 notappargs :: (String, Transform) -> Transform
167 notappargs trans = applyboth (subnotappargs trans) trans
168
169 -- Apply the given transformation to all (direct and indirect) subexpressions
170 -- (but not the expression itself), except for direct arguments of an
171 -- application
172 subnotappargs :: (String, Transform) -> Transform
173 subnotappargs trans (App a b) = do
174   a' <- subnotappargs trans a
175   b' <- subnotappargs trans b
176   return $ App a' b'
177
178 -- Let subeverywhere handle all other expressions
179 subnotappargs trans expr = subeverywhere (notappargs trans) expr
180
181 -- Runs each of the transforms repeatedly inside the State monad.
182 dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
183 dotransforms transs expr = do
184   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
185   if Monoid.getAny changed then dotransforms transs expr' else return expr'
186
187 -- Inline all let bindings that satisfy the given condition
188 inlinebind :: ((CoreBndr, CoreExpr) -> Bool) -> Transform
189 inlinebind condition (Let (Rec binds) expr) | not $ null replace =
190     change newexpr
191   where 
192     -- Find all simple bindings
193     (replace, others) = List.partition condition binds
194     -- Substitute the to be replaced binders with their expression
195     newexpr = substitute replace (Let (Rec others) expr)
196 -- Leave all other expressions unchanged
197 inlinebind _ expr = return expr
198
199 -- Sets the changed flag in the TransformMonad, to signify that some
200 -- transform has changed the result
201 setChanged :: TransformMonad ()
202 setChanged = Writer.tell (Monoid.Any True)
203
204 -- Sets the changed flag and returns the given value.
205 change :: a -> TransformMonad a
206 change val = do
207   setChanged
208   return val
209
210 -- Create a new Unique
211 mkUnique :: TransformMonad Unique.Unique
212 mkUnique = Trans.lift $ do
213     us <- getA tsUniqSupply 
214     let (us', us'') = UniqSupply.splitUniqSupply us
215     putA tsUniqSupply us'
216     return $ UniqSupply.uniqFromSupply us''
217
218 -- Replace each of the binders given with the coresponding expressions in the
219 -- given expression.
220 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
221 substitute [] expr = expr
222 -- Apply one substitution on the expression, but also on any remaining
223 -- substitutions. This seems to be the only way to handle substitutions like
224 -- [(b, c), (a, b)]. This means we reuse a substitution, which is not allowed
225 -- according to CoreSubst documentation (but it doesn't seem to be a problem).
226 -- TODO: Find out how this works, exactly.
227 substitute ((b, e):subss) expr = substitute subss' expr'
228   where 
229     -- Create the Subst
230     subs = (CoreSubst.extendSubst CoreSubst.emptySubst b e)
231     -- Apply this substitution to the main expression
232     expr' = CoreSubst.substExpr subs expr
233     -- Apply this substitution on all the expressions in the remaining
234     -- substitutions
235     subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
236
237 -- Run a given TransformSession. Used mostly to setup the right calls and
238 -- an initial state.
239 runTransformSession :: UniqSupply.UniqSupply -> TransformSession a -> a
240 runTransformSession uniqSupply session = State.evalState session (emptyTransformState uniqSupply)
241
242 -- Is the given expression representable at runtime, based on the type?
243 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
244 isRepr (Type ty) = return False
245 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)