Don't propagate types with free tyvars.
[matthijs/master-project/cλash.git] / Normalize.hs
1 {-# LANGUAGE PackageImports #-}
2 --
3 -- Functions to bring a Core expression in normal form. This module provides a
4 -- top level function "normalize", and defines the actual transformation passes that
5 -- are performed.
6 --
7 module Normalize (normalizeModule) where
8
9 -- Standard modules
10 import Debug.Trace
11 import qualified Maybe
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import qualified Control.Monad as Monad
14 import qualified Data.Map as Map
15 import Data.Accessor
16
17 -- GHC API
18 import CoreSyn
19 import qualified UniqSupply
20 import qualified CoreUtils
21 import qualified Type
22 import qualified Id
23 import qualified Var
24 import qualified VarSet
25 import qualified CoreFVs
26 import Outputable ( showSDoc, ppr, nest )
27
28 -- Local imports
29 import NormalizeTypes
30 import NormalizeTools
31 import CoreTools
32
33 --------------------------------
34 -- Start of transformations
35 --------------------------------
36
37 --------------------------------
38 -- η abstraction
39 --------------------------------
40 eta, etatop :: Transform
41 eta expr | is_fun expr && not (is_lam expr) = do
42   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
43   id <- mkInternalVar "param" arg_ty
44   change (Lam id (App expr (Var id)))
45 -- Leave all other expressions unchanged
46 eta e = return e
47 etatop = notapplied ("eta", eta)
48
49 --------------------------------
50 -- β-reduction
51 --------------------------------
52 beta, betatop :: Transform
53 -- Substitute arg for x in expr
54 beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
55 -- Propagate the application into the let
56 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
57 -- Propagate the application into each of the alternatives
58 beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts'
59   where 
60     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts
61     (_, ty') = Type.splitFunTy ty
62 -- Leave all other expressions unchanged
63 beta expr = return expr
64 -- Perform this transform everywhere
65 betatop = everywhere ("beta", beta)
66
67 --------------------------------
68 -- let recursification
69 --------------------------------
70 letrec, letrectop :: Transform
71 letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
72 -- Leave all other expressions unchanged
73 letrec expr = return expr
74 -- Perform this transform everywhere
75 letrectop = everywhere ("letrec", letrec)
76
77 --------------------------------
78 -- let simplification
79 --------------------------------
80 letsimpl, letsimpltop :: Transform
81 -- Don't simplifiy lets that are already simple
82 letsimpl expr@(Let _ (Var _)) = return expr
83 -- Put the "in ..." value of a let in its own binding, but not when the
84 -- expression is applicable (to prevent loops with inlinefun).
85 letsimpl (Let (Rec binds) expr) | not $ is_applicable expr = do
86   id <- mkInternalVar "foo" (CoreUtils.exprType expr)
87   let bind = (id, expr)
88   change $ Let (Rec (bind:binds)) (Var id)
89 -- Leave all other expressions unchanged
90 letsimpl expr = return expr
91 -- Perform this transform everywhere
92 letsimpltop = everywhere ("letsimpl", letsimpl)
93
94 --------------------------------
95 -- let flattening
96 --------------------------------
97 letflat, letflattop :: Transform
98 letflat (Let (Rec binds) expr) = do
99   -- Turn each binding into a list of bindings (possibly containing just one
100   -- element, of course)
101   bindss <- Monad.mapM flatbind binds
102   -- Concat all the bindings
103   let binds' = concat bindss
104   -- Return the new let. We don't use change here, since possibly nothing has
105   -- changed. If anything has changed, flatbind has already flagged that
106   -- change.
107   return $ Let (Rec binds') expr
108   where
109     -- Turns a binding of a let into a multiple bindings, or any other binding
110     -- into a list with just that binding
111     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
112     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
113     flatbind (b, expr) = return [(b, expr)]
114 -- Leave all other expressions unchanged
115 letflat expr = return expr
116 -- Perform this transform everywhere
117 letflattop = everywhere ("letflat", letflat)
118
119 --------------------------------
120 -- Simple let binding removal
121 --------------------------------
122 -- Remove a = b bindings from let expressions everywhere
123 letremovetop :: Transform
124 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) -> True; otherwise -> False))
125
126 --------------------------------
127 -- Function inlining
128 --------------------------------
129 -- Remove a = B bindings, with B :: a -> b, or B :: forall x . T, from let
130 -- expressions everywhere. This means that any value that still needs to be
131 -- applied to something else (polymorphic values need to be applied to a
132 -- Type) will be inlined, and will eventually be applied to all their
133 -- arguments.
134 --
135 -- This is a tricky function, which is prone to create loops in the
136 -- transformations. To fix this, we make sure that no transformation will
137 -- create a new let binding with a function type. These other transformations
138 -- will just not work on those function-typed values at first, but the other
139 -- transformations (in particular β-reduction) should make sure that the type
140 -- of those values eventually becomes primitive.
141 inlinefuntop :: Transform
142 inlinefuntop = everywhere ("inlinefun", inlinebind (is_applicable . snd))
143
144 --------------------------------
145 -- Scrutinee simplification
146 --------------------------------
147 scrutsimpl,scrutsimpltop :: Transform
148 -- Don't touch scrutinees that are already simple
149 scrutsimpl expr@(Case (Var _) _ _ _) = return expr
150 -- Replace all other cases with a let that binds the scrutinee and a new
151 -- simple scrutinee, but not when the scrutinee is applicable (to prevent
152 -- loops with inlinefun, though I don't think a scrutinee can be
153 -- applicable...)
154 scrutsimpl (Case scrut b ty alts) | not $ is_applicable scrut = do
155   id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
156   change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
157 -- Leave all other expressions unchanged
158 scrutsimpl expr = return expr
159 -- Perform this transform everywhere
160 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
161
162 --------------------------------
163 -- Case binder wildening
164 --------------------------------
165 casewild, casewildtop :: Transform
166 casewild expr@(Case scrut b ty alts) = do
167   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
168   let bindings = concat bindingss
169   -- Replace the case with a let with bindings and a case
170   let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
171   -- If there are no non-wild binders, or this case is already a simple
172   -- selector (i.e., a single alt with exactly one binding), already a simple
173   -- selector altan no bindings (i.e., no wild binders in the original case),
174   -- don't change anything, otherwise, replace the case.
175   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
176   where
177   -- Generate a single wild binder, since they are all the same
178   wild = Id.mkWildId
179   -- Wilden the binders of one alt, producing a list of bindings as a
180   -- sideeffect.
181   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
182   doalt (con, bndrs, expr) = do
183     bindings_maybe <- Monad.zipWithM mkextracts bndrs [0..]
184     let bindings = Maybe.catMaybes bindings_maybe
185     -- We replace the binders with wild binders only. We can leave expr
186     -- unchanged, since the new bindings bind the same vars as the original
187     -- did.
188     let newalt = (con, wildbndrs, expr)
189     return (bindings, newalt)
190     where
191       -- Make all binders wild
192       wildbndrs = map (\bndr -> Id.mkWildId (Id.idType bndr)) bndrs
193       -- Creates a case statement to retrieve the ith element from the scrutinee
194       -- and binds that to b.
195       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
196       mkextracts b i =
197         if is_wild b || Type.isFunTy (Id.idType b) 
198           -- Don't create extra bindings for binders that are already wild, or
199           -- for binders that bind function types (to prevent loops with
200           -- inlinefun).
201           then return Nothing
202           else do
203             -- Create on new binder that will actually capture a value in this
204             -- case statement, and return it
205             let bty = (Id.idType b)
206             id <- mkInternalVar "sel" bty
207             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
208             return $ Just (b, Case scrut b bty [(con, binders, Var id)])
209 -- Leave all other expressions unchanged
210 casewild expr = return expr
211 -- Perform this transform everywhere
212 casewildtop = everywhere ("casewild", casewild)
213
214 --------------------------------
215 -- Case value simplification
216 --------------------------------
217 casevalsimpl, casevalsimpltop :: Transform
218 casevalsimpl expr@(Case scrut b ty alts) = do
219   -- Try to simplify each alternative, resulting in an optional binding and a
220   -- new alternative.
221   (bindings_maybe, alts') <- (Monad.liftM unzip) $ mapM doalt alts
222   let bindings = Maybe.catMaybes bindings_maybe
223   -- Create a new let around the case, that binds of the cases values.
224   let newlet = Let (Rec bindings) (Case scrut b ty alts')
225   -- If there were no values that needed and allowed simplification, don't
226   -- change the case.
227   if null bindings then return expr else change newlet 
228   where
229     doalt :: CoreAlt -> TransformMonad (Maybe (CoreBndr, CoreExpr), CoreAlt)
230     -- Don't simplify values that are already simple
231     doalt alt@(con, bndrs, Var _) = return (Nothing, alt)
232     -- Simplify each alt by creating a new id, binding the case value to it and
233     -- replacing the case value with that id. Only do this when the case value
234     -- does not use any of the binders bound by this alternative, for that would
235     -- cause those binders to become unbound when moving the value outside of
236     -- the case statement. Also, don't create a binding for applicable
237     -- expressions, to prevent loops with inlinefun.
238     doalt (con, bndrs, expr) | (not usesvars) && (not $ is_applicable expr) = do
239       id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
240       -- We don't flag a change here, since casevalsimpl will do that above
241       -- based on Just we return here.
242       return $ (Just (id, expr), (con, bndrs, Var id))
243       -- Find if any of the binders are used by expr
244       where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
245     -- Don't simplify anything else
246     doalt alt = return (Nothing, alt)
247 -- Leave all other expressions unchanged
248 casevalsimpl expr = return expr
249 -- Perform this transform everywhere
250 casevalsimpltop = everywhere ("casevalsimpl", casevalsimpl)
251
252 --------------------------------
253 -- Case removal
254 --------------------------------
255 -- Remove case statements that have only a single alternative and only wild
256 -- binders.
257 caseremove, caseremovetop :: Transform
258 -- Replace a useless case by the value of its single alternative
259 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
260     -- Find if any of the binders are used by expr
261     where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
262 -- Leave all other expressions unchanged
263 caseremove expr = return expr
264 -- Perform this transform everywhere
265 caseremovetop = everywhere ("caseremove", caseremove)
266
267 --------------------------------
268 -- Application simplification
269 --------------------------------
270 -- Make sure that all arguments in an application are simple variables.
271 appsimpl, appsimpltop :: Transform
272 -- Don't simplify arguments that are already simple
273 appsimpl expr@(App f (Var _)) = return expr
274 -- Simplify all non-applicable (to prevent loops with inlinefun) arguments,
275 -- except for type arguments (since a let can't bind type vars, only a lambda
276 -- can). Do this by introducing a new Let that binds the argument and passing
277 -- the new binder in the application.
278 appsimpl (App f expr) | (not $ is_applicable expr) && (not $ CoreSyn.isTypeArg expr) = do
279   id <- mkInternalVar "arg" (CoreUtils.exprType expr)
280   change $ Let (Rec [(id, expr)]) (App f (Var id))
281 -- Leave all other expressions unchanged
282 appsimpl expr = return expr
283 -- Perform this transform everywhere
284 appsimpltop = everywhere ("appsimpl", appsimpl)
285
286
287 --------------------------------
288 -- Type argument propagation
289 --------------------------------
290 -- Remove all applications to type arguments, by duplicating the function
291 -- called with the type application in its new definition. We leave
292 -- dictionaries that might be associated with the type untouched, the funprop
293 -- transform should propagate these later on.
294 typeprop, typeproptop :: Transform
295 -- Transform any function that is applied to a type argument. Since type
296 -- arguments are always the first ones to apply and we'll remove all type
297 -- arguments, we can simply do them one by one. We only propagate type
298 -- arguments without any free tyvars, since tyvars those wouldn't be in scope
299 -- in the new function.
300 typeprop expr@(App (Var f) arg@(Type ty)) | not $ has_free_tyvars arg = do
301   id <- cloneVar f
302   let newty = Type.applyTy (Id.idType f) ty
303   let newf = Var.setVarType id newty
304   body_maybe <- Trans.lift $ getGlobalBind f
305   case body_maybe of
306     Just body -> do
307       let newbody = App body (Type ty)
308       Trans.lift $ addGlobalBind newf newbody
309       change (Var newf)
310     -- If we don't have a body for the function called, leave it unchanged (it
311     -- should be a primitive function then).
312     Nothing -> return expr
313 -- Leave all other expressions unchanged
314 typeprop expr = return expr
315 -- Perform this transform everywhere
316 typeproptop = everywhere ("typeprop", typeprop)
317
318 -- TODO: introduce top level let if needed?
319
320 --------------------------------
321 -- End of transformations
322 --------------------------------
323
324
325
326
327 -- What transforms to run?
328 transforms = [typeproptop, etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
329
330 -- Turns the given bind into VHDL
331 normalizeModule :: 
332   UniqSupply.UniqSupply -- ^ A UniqSupply we can use
333   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
334   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
335   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
336   -> [(CoreBndr, CoreExpr)] -- ^ The resulting VHDL
337
338 normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession uniqsupply $ do
339   -- Put all the bindings in this module in the tsBindings map
340   putA tsBindings (Map.fromList bindings)
341   -- (Recursively) normalize each of the requested bindings
342   mapM normalizeBind generate_for
343   -- Get all initial bindings and the ones we produced
344   bindings_map <- getA tsBindings
345   let bindings = Map.assocs bindings_map
346   normalized_bindings <- getA tsNormalized
347   -- But return only the normalized bindings
348   return $ filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings
349
350 normalizeBind :: CoreBndr -> TransformSession ()
351 normalizeBind bndr = do
352   normalized_funcs <- getA tsNormalized
353   -- See if this function was normalized already
354   if VarSet.elemVarSet bndr normalized_funcs
355     then
356       -- Yup, don't do it again
357       return ()
358     else do
359       -- Nope, note that it has been and do it.
360       modA tsNormalized (flip VarSet.extendVarSet bndr)
361       expr_maybe <- getGlobalBind bndr
362       case expr_maybe of 
363         Just expr -> do
364           -- Normalize this expression
365           expr' <- dotransforms transforms expr
366           let expr'' = trace ("Before:\n\n" ++ showSDoc ( ppr expr ) ++ "\n\nAfter:\n\n" ++ showSDoc ( ppr expr')) expr'
367           -- And store the normalized version in the session
368           modA tsBindings (Map.insert bndr expr'')
369           -- Find all vars used with a function type. All of these should be global
370           -- binders (i.e., functions used), since any local binders with a function
371           -- type should have been inlined already.
372           let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> (Type.isFunTy . snd . Type.splitForAllTys . Id.idType) v) expr''
373           let used_funcs = VarSet.varSetElems used_funcs_set
374           -- Process each of the used functions recursively
375           mapM normalizeBind used_funcs
376           return ()
377         -- We don't have a value for this binder, let's assume this is a builtin
378         -- function. This might need some extra checking and a nice error
379         -- message).
380         Nothing -> return ()