Recursively normalize binds.
[matthijs/master-project/cλash.git] / Normalize.hs
1 --
2 -- Functions to bring a Core expression in normal form. This module provides a
3 -- top level function "normalize", and defines the actual transformation passes that
4 -- are performed.
5 --
6 module Normalize (normalizeModule) where
7
8 -- Standard modules
9 import Debug.Trace
10 import qualified Maybe
11 import qualified Control.Monad as Monad
12 import qualified Data.Map as Map
13 import Data.Accessor
14
15 -- GHC API
16 import CoreSyn
17 import qualified UniqSupply
18 import qualified CoreUtils
19 import qualified Type
20 import qualified Id
21 import qualified VarSet
22 import qualified CoreFVs
23 import Outputable ( showSDoc, ppr, nest )
24
25 -- Local imports
26 import NormalizeTypes
27 import NormalizeTools
28 import CoreTools
29
30 --------------------------------
31 -- Start of transformations
32 --------------------------------
33
34 --------------------------------
35 -- η abstraction
36 --------------------------------
37 eta, etatop :: Transform
38 eta expr | is_fun expr && not (is_lam expr) = do
39   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
40   id <- mkInternalVar "param" arg_ty
41   change (Lam id (App expr (Var id)))
42 -- Leave all other expressions unchanged
43 eta e = return e
44 etatop = notapplied ("eta", eta)
45
46 --------------------------------
47 -- β-reduction
48 --------------------------------
49 beta, betatop :: Transform
50 -- Substitute arg for x in expr
51 beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
52 -- Propagate the application into the let
53 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
54 -- Propagate the application into each of the alternatives
55 beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts'
56   where 
57     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts
58     (_, ty') = Type.splitFunTy ty
59 -- Leave all other expressions unchanged
60 beta expr = return expr
61 -- Perform this transform everywhere
62 betatop = everywhere ("beta", beta)
63
64 --------------------------------
65 -- let recursification
66 --------------------------------
67 letrec, letrectop :: Transform
68 letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
69 -- Leave all other expressions unchanged
70 letrec expr = return expr
71 -- Perform this transform everywhere
72 letrectop = everywhere ("letrec", letrec)
73
74 --------------------------------
75 -- let simplification
76 --------------------------------
77 letsimpl, letsimpltop :: Transform
78 -- Don't simplifiy lets that are already simple
79 letsimpl expr@(Let _ (Var _)) = return expr
80 -- Put the "in ..." value of a let in its own binding, but not when the
81 -- expression has a function type (to prevent loops with inlinefun).
82 letsimpl (Let (Rec binds) expr) | not $ is_fun expr = do
83   id <- mkInternalVar "foo" (CoreUtils.exprType expr)
84   let bind = (id, expr)
85   change $ Let (Rec (bind:binds)) (Var id)
86 -- Leave all other expressions unchanged
87 letsimpl expr = return expr
88 -- Perform this transform everywhere
89 letsimpltop = everywhere ("letsimpl", letsimpl)
90
91 --------------------------------
92 -- let flattening
93 --------------------------------
94 letflat, letflattop :: Transform
95 letflat (Let (Rec binds) expr) = do
96   -- Turn each binding into a list of bindings (possibly containing just one
97   -- element, of course)
98   bindss <- Monad.mapM flatbind binds
99   -- Concat all the bindings
100   let binds' = concat bindss
101   -- Return the new let. We don't use change here, since possibly nothing has
102   -- changed. If anything has changed, flatbind has already flagged that
103   -- change.
104   return $ Let (Rec binds') expr
105   where
106     -- Turns a binding of a let into a multiple bindings, or any other binding
107     -- into a list with just that binding
108     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
109     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
110     flatbind (b, expr) = return [(b, expr)]
111 -- Leave all other expressions unchanged
112 letflat expr = return expr
113 -- Perform this transform everywhere
114 letflattop = everywhere ("letflat", letflat)
115
116 --------------------------------
117 -- Simple let binding removal
118 --------------------------------
119 -- Remove a = b bindings from let expressions everywhere
120 letremovetop :: Transform
121 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) -> True; otherwise -> False))
122
123 --------------------------------
124 -- Function inlining
125 --------------------------------
126 -- Remove a = B bindings, with B :: a -> b, from let expressions everywhere.
127 -- This is a tricky function, which is prone to create loops in the
128 -- transformations. To fix this, we make sure that no transformation will
129 -- create a new let binding with a function type. These other transformations
130 -- will just not work on those function-typed values at first, but the other
131 -- transformations (in particular β-reduction) should make sure that the type
132 -- of those values eventually becomes primitive.
133 inlinefuntop :: Transform
134 inlinefuntop = everywhere ("inlinefun", inlinebind (Type.isFunTy . CoreUtils.exprType . snd))
135
136 --------------------------------
137 -- Scrutinee simplification
138 --------------------------------
139 scrutsimpl,scrutsimpltop :: Transform
140 -- Don't touch scrutinees that are already simple
141 scrutsimpl expr@(Case (Var _) _ _ _) = return expr
142 -- Replace all other cases with a let that binds the scrutinee and a new
143 -- simple scrutinee, but not when the scrutinee is a function type (to prevent
144 -- loops with inlinefun, though I don't think a scrutinee can have a function
145 -- type...)
146 scrutsimpl (Case scrut b ty alts) | not $ is_fun scrut = do
147   id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
148   change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
149 -- Leave all other expressions unchanged
150 scrutsimpl expr = return expr
151 -- Perform this transform everywhere
152 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
153
154 --------------------------------
155 -- Case binder wildening
156 --------------------------------
157 casewild, casewildtop :: Transform
158 casewild expr@(Case scrut b ty alts) = do
159   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
160   let bindings = concat bindingss
161   -- Replace the case with a let with bindings and a case
162   let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
163   -- If there are no non-wild binders, or this case is already a simple
164   -- selector (i.e., a single alt with exactly one binding), already a simple
165   -- selector altan no bindings (i.e., no wild binders in the original case),
166   -- don't change anything, otherwise, replace the case.
167   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
168   where
169   -- Generate a single wild binder, since they are all the same
170   wild = Id.mkWildId
171   -- Wilden the binders of one alt, producing a list of bindings as a
172   -- sideeffect.
173   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
174   doalt (con, bndrs, expr) = do
175     bindings_maybe <- Monad.zipWithM mkextracts bndrs [0..]
176     let bindings = Maybe.catMaybes bindings_maybe
177     -- We replace the binders with wild binders only. We can leave expr
178     -- unchanged, since the new bindings bind the same vars as the original
179     -- did.
180     let newalt = (con, wildbndrs, expr)
181     return (bindings, newalt)
182     where
183       -- Make all binders wild
184       wildbndrs = map (\bndr -> Id.mkWildId (Id.idType bndr)) bndrs
185       -- Creates a case statement to retrieve the ith element from the scrutinee
186       -- and binds that to b.
187       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
188       mkextracts b i =
189         if is_wild b || Type.isFunTy (Id.idType b) 
190           -- Don't create extra bindings for binders that are already wild, or
191           -- for binders that bind function types (to prevent loops with
192           -- inlinefun).
193           then return Nothing
194           else do
195             -- Create on new binder that will actually capture a value in this
196             -- case statement, and return it
197             let bty = (Id.idType b)
198             id <- mkInternalVar "sel" bty
199             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
200             return $ Just (b, Case scrut b bty [(con, binders, Var id)])
201 -- Leave all other expressions unchanged
202 casewild expr = return expr
203 -- Perform this transform everywhere
204 casewildtop = everywhere ("casewild", casewild)
205
206 --------------------------------
207 -- Case value simplification
208 --------------------------------
209 casevalsimpl, casevalsimpltop :: Transform
210 casevalsimpl expr@(Case scrut b ty alts) = do
211   -- Try to simplify each alternative, resulting in an optional binding and a
212   -- new alternative.
213   (bindings_maybe, alts') <- (Monad.liftM unzip) $ mapM doalt alts
214   let bindings = Maybe.catMaybes bindings_maybe
215   -- Create a new let around the case, that binds of the cases values.
216   let newlet = Let (Rec bindings) (Case scrut b ty alts')
217   -- If there were no values that needed and allowed simplification, don't
218   -- change the case.
219   if null bindings then return expr else change newlet 
220   where
221     doalt :: CoreAlt -> TransformMonad (Maybe (CoreBndr, CoreExpr), CoreAlt)
222     -- Don't simplify values that are already simple
223     doalt alt@(con, bndrs, Var _) = return (Nothing, alt)
224     -- Simplify each alt by creating a new id, binding the case value to it and
225     -- replacing the case value with that id. Only do this when the case value
226     -- does not use any of the binders bound by this alternative, for that would
227     -- cause those binders to become unbound when moving the value outside of
228     -- the case statement. Also, don't create a binding for function-typed
229     -- expressions, to prevent loops with inlinefun.
230     doalt (con, bndrs, expr) | (not usesvars) && (not $ is_fun expr) = do
231       id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
232       -- We don't flag a change here, since casevalsimpl will do that above
233       -- based on Just we return here.
234       return $ (Just (id, expr), (con, bndrs, Var id))
235       -- Find if any of the binders are used by expr
236       where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
237     -- Don't simplify anything else
238     doalt alt = return (Nothing, alt)
239 -- Leave all other expressions unchanged
240 casevalsimpl expr = return expr
241 -- Perform this transform everywhere
242 casevalsimpltop = everywhere ("casevalsimpl", casevalsimpl)
243
244 --------------------------------
245 -- Case removal
246 --------------------------------
247 -- Remove case statements that have only a single alternative and only wild
248 -- binders.
249 caseremove, caseremovetop :: Transform
250 -- Replace a useless case by the value of its single alternative
251 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
252     -- Find if any of the binders are used by expr
253     where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
254 -- Leave all other expressions unchanged
255 caseremove expr = return expr
256 -- Perform this transform everywhere
257 caseremovetop = everywhere ("caseremove", caseremove)
258
259 --------------------------------
260 -- Application simplification
261 --------------------------------
262 -- Make sure that all arguments in an application are simple variables.
263 appsimpl, appsimpltop :: Transform
264 -- Don't simplify arguments that are already simple
265 appsimpl expr@(App f (Var _)) = return expr
266 -- Simplify all arguments that do not have a function type (to prevent loops
267 -- with inlinefun) and is not a type argument. Do this by introducing a new
268 -- Let that binds the argument and passing the new binder in the application.
269 appsimpl (App f expr) | (not $ is_fun expr) && (not $ CoreSyn.isTypeArg expr) = do
270   id <- mkInternalVar "arg" (CoreUtils.exprType expr)
271   change $ Let (Rec [(id, expr)]) (App f (Var id))
272 -- Leave all other expressions unchanged
273 appsimpl expr = return expr
274 -- Perform this transform everywhere
275 appsimpltop = everywhere ("appsimpl", appsimpl)
276
277 -- TODO: introduce top level let if needed?
278
279 --------------------------------
280 -- End of transformations
281 --------------------------------
282
283
284
285
286 -- What transforms to run?
287 transforms = [etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
288
289 -- Turns the given bind into VHDL
290 normalizeModule :: 
291   UniqSupply.UniqSupply -- ^ A UniqSupply we can use
292   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
293   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
294   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
295   -> [(CoreBndr, CoreExpr)] -- ^ The resulting VHDL
296
297 normalizeModule uniqsupply bindings generate_for statefuls = runTransformSession uniqsupply $ do
298   -- Put all the bindings in this module in the tsBindings map
299   putA tsBindings (Map.fromList bindings)
300   -- (Recursively) normalize each of the requested bindings
301   mapM normalizeBind generate_for
302   -- Get all initial bindings and the ones we produced
303   bindings_map <- getA tsBindings
304   let bindings = Map.assocs bindings_map
305   normalized_bindings <- getA tsNormalized
306   -- But return only the normalized bindings
307   return $ filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings
308
309 normalizeBind :: CoreBndr -> TransformSession ()
310 normalizeBind bndr = do
311   normalized_funcs <- getA tsNormalized
312   -- See if this function was normalized already
313   if VarSet.elemVarSet bndr normalized_funcs
314     then
315       -- Yup, don't do it again
316       return ()
317     else do
318       -- Nope, note that it has been and do it.
319       modA tsNormalized (flip VarSet.extendVarSet bndr)
320       expr_maybe <- getGlobalBind bndr
321       case expr_maybe of 
322         Just expr -> do
323           -- Normalize this expression
324           expr' <- dotransforms transforms expr
325           let expr'' = trace ("Before:\n\n" ++ showSDoc ( ppr expr ) ++ "\n\nAfter:\n\n" ++ showSDoc ( ppr expr')) expr'
326           -- And store the normalized version in the session
327           modA tsBindings (Map.insert bndr expr'')
328           -- Find all vars used with a function type. All of these should be global
329           -- binders (i.e., functions used), since any local binders with a function
330           -- type should have been inlined already.
331           let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> trace (showSDoc $ ppr $ Id.idType v) ((Type.isFunTy . snd . Type.splitForAllTys . Id.idType)v)) expr''
332           let used_funcs = VarSet.varSetElems used_funcs_set
333           -- Process each of the used functions recursively
334           mapM normalizeBind (trace (show used_funcs) used_funcs)
335           return ()
336         -- We don't have a value for this binder, let's assume this is a builtin
337         -- function. This might need some extra checking and a nice error
338         -- message).
339         Nothing -> return ()