Adepted the modules to their new structure
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
1 {-# LANGUAGE PackageImports #-}
2 --
3 -- Functions to bring a Core expression in normal form. This module provides a
4 -- top level function "normalize", and defines the actual transformation passes that
5 -- are performed.
6 --
7 module CLasH.Normalize (normalizeModule) where
8
9 -- Standard modules
10 import Debug.Trace
11 import qualified Maybe
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import qualified Control.Monad as Monad
14 import qualified Control.Monad.Trans.Writer as Writer
15 import qualified Data.Map as Map
16 import qualified Data.Monoid as Monoid
17 import Data.Accessor
18
19 -- GHC API
20 import CoreSyn
21 import qualified UniqSupply
22 import qualified CoreUtils
23 import qualified Type
24 import qualified TcType
25 import qualified Id
26 import qualified Var
27 import qualified VarSet
28 import qualified NameSet
29 import qualified CoreFVs
30 import qualified CoreUtils
31 import qualified MkCore
32 import qualified HscTypes
33 import Outputable ( showSDoc, ppr, nest )
34
35 -- Local imports
36 import CLasH.Normalize.NormalizeTypes
37 import CLasH.Normalize.NormalizeTools
38 import CLasH.VHDL.VHDLTypes
39 import CLasH.Utils.Core.CoreTools
40 import CLasH.Utils.Pretty
41
42 --------------------------------
43 -- Start of transformations
44 --------------------------------
45
46 --------------------------------
47 -- η abstraction
48 --------------------------------
49 eta, etatop :: Transform
50 eta expr | is_fun expr && not (is_lam expr) = do
51   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
52   id <- mkInternalVar "param" arg_ty
53   change (Lam id (App expr (Var id)))
54 -- Leave all other expressions unchanged
55 eta e = return e
56 etatop = notappargs ("eta", eta)
57
58 --------------------------------
59 -- β-reduction
60 --------------------------------
61 beta, betatop :: Transform
62 -- Substitute arg for x in expr
63 beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
64 -- Propagate the application into the let
65 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
66 -- Propagate the application into each of the alternatives
67 beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts'
68   where 
69     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts
70     ty' = CoreUtils.applyTypeToArg ty arg
71 -- Leave all other expressions unchanged
72 beta expr = return expr
73 -- Perform this transform everywhere
74 betatop = everywhere ("beta", beta)
75
76 --------------------------------
77 -- Cast propagation
78 --------------------------------
79 -- Try to move casts as much downward as possible.
80 castprop, castproptop :: Transform
81 castprop (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty)
82 castprop expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts')
83   where
84     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (Cast expr ty))) alts
85 -- Leave all other expressions unchanged
86 castprop expr = return expr
87 -- Perform this transform everywhere
88 castproptop = everywhere ("castprop", castprop)
89
90 --------------------------------
91 -- let recursification
92 --------------------------------
93 letrec, letrectop :: Transform
94 letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
95 -- Leave all other expressions unchanged
96 letrec expr = return expr
97 -- Perform this transform everywhere
98 letrectop = everywhere ("letrec", letrec)
99
100 --------------------------------
101 -- let simplification
102 --------------------------------
103 letsimpl, letsimpltop :: Transform
104 -- Put the "in ..." value of a let in its own binding, but not when the
105 -- expression is applicable (to prevent loops with inlinefun).
106 letsimpl expr@(Let (Rec binds) res) | not $ is_applicable expr = do
107   local_var <- Trans.lift $ is_local_var res
108   if not local_var
109     then do
110       -- If the result is not a local var already (to prevent loops with
111       -- ourselves), extract it.
112       id <- mkInternalVar "foo" (CoreUtils.exprType res)
113       let bind = (id, res)
114       change $ Let (Rec (bind:binds)) (Var id)
115     else
116       -- If the result is already a local var, don't extract it.
117       return expr
118
119 -- Leave all other expressions unchanged
120 letsimpl expr = return expr
121 -- Perform this transform everywhere
122 letsimpltop = everywhere ("letsimpl", letsimpl)
123
124 --------------------------------
125 -- let flattening
126 --------------------------------
127 letflat, letflattop :: Transform
128 letflat (Let (Rec binds) expr) = do
129   -- Turn each binding into a list of bindings (possibly containing just one
130   -- element, of course)
131   bindss <- Monad.mapM flatbind binds
132   -- Concat all the bindings
133   let binds' = concat bindss
134   -- Return the new let. We don't use change here, since possibly nothing has
135   -- changed. If anything has changed, flatbind has already flagged that
136   -- change.
137   return $ Let (Rec binds') expr
138   where
139     -- Turns a binding of a let into a multiple bindings, or any other binding
140     -- into a list with just that binding
141     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
142     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
143     flatbind (b, expr) = return [(b, expr)]
144 -- Leave all other expressions unchanged
145 letflat expr = return expr
146 -- Perform this transform everywhere
147 letflattop = everywhere ("letflat", letflat)
148
149 --------------------------------
150 -- Simple let binding removal
151 --------------------------------
152 -- Remove a = b bindings from let expressions everywhere
153 letremovetop :: Transform
154 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
155
156 --------------------------------
157 -- Function inlining
158 --------------------------------
159 -- Remove a = B bindings, with B :: a -> b, or B :: forall x . T, from let
160 -- expressions everywhere. This means that any value that still needs to be
161 -- applied to something else (polymorphic values need to be applied to a
162 -- Type) will be inlined, and will eventually be applied to all their
163 -- arguments.
164 --
165 -- This is a tricky function, which is prone to create loops in the
166 -- transformations. To fix this, we make sure that no transformation will
167 -- create a new let binding with a function type. These other transformations
168 -- will just not work on those function-typed values at first, but the other
169 -- transformations (in particular β-reduction) should make sure that the type
170 -- of those values eventually becomes primitive.
171 inlinenonreptop :: Transform
172 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
173
174 --------------------------------
175 -- Scrutinee simplification
176 --------------------------------
177 scrutsimpl,scrutsimpltop :: Transform
178 -- Don't touch scrutinees that are already simple
179 scrutsimpl expr@(Case (Var _) _ _ _) = return expr
180 -- Replace all other cases with a let that binds the scrutinee and a new
181 -- simple scrutinee, but not when the scrutinee is applicable (to prevent
182 -- loops with inlinefun, though I don't think a scrutinee can be
183 -- applicable...)
184 scrutsimpl (Case scrut b ty alts) | not $ is_applicable scrut = do
185   id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
186   change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
187 -- Leave all other expressions unchanged
188 scrutsimpl expr = return expr
189 -- Perform this transform everywhere
190 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
191
192 --------------------------------
193 -- Case binder wildening
194 --------------------------------
195 casewild, casewildtop :: Transform
196 casewild expr@(Case scrut b ty alts) = do
197   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
198   let bindings = concat bindingss
199   -- Replace the case with a let with bindings and a case
200   let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
201   -- If there are no non-wild binders, or this case is already a simple
202   -- selector (i.e., a single alt with exactly one binding), already a simple
203   -- selector altan no bindings (i.e., no wild binders in the original case),
204   -- don't change anything, otherwise, replace the case.
205   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
206   where
207   -- Generate a single wild binder, since they are all the same
208   wild = MkCore.mkWildBinder
209   -- Wilden the binders of one alt, producing a list of bindings as a
210   -- sideeffect.
211   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
212   doalt (con, bndrs, expr) = do
213     bindings_maybe <- Monad.zipWithM mkextracts bndrs [0..]
214     let bindings = Maybe.catMaybes bindings_maybe
215     -- We replace the binders with wild binders only. We can leave expr
216     -- unchanged, since the new bindings bind the same vars as the original
217     -- did.
218     let newalt = (con, wildbndrs, expr)
219     return (bindings, newalt)
220     where
221       -- Make all binders wild
222       wildbndrs = map (\bndr -> MkCore.mkWildBinder (Id.idType bndr)) bndrs
223       -- A set of all the binders that are used by the expression
224       free_vars = CoreFVs.exprSomeFreeVars (`elem` bndrs) expr
225       -- Creates a case statement to retrieve the ith element from the scrutinee
226       -- and binds that to b.
227       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
228       mkextracts b i =
229         if not (VarSet.elemVarSet b free_vars) || Type.isFunTy (Id.idType b) 
230           -- Don't create extra bindings for binders that are already wild
231           -- (e.g. not in the free variables of expr, so unused), or for
232           -- binders that bind function types (to prevent loops with
233           -- inlinefun).
234           then return Nothing
235           else do
236             -- Create on new binder that will actually capture a value in this
237             -- case statement, and return it
238             let bty = (Id.idType b)
239             id <- mkInternalVar "sel" bty
240             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
241             return $ Just (b, Case scrut b bty [(con, binders, Var id)])
242 -- Leave all other expressions unchanged
243 casewild expr = return expr
244 -- Perform this transform everywhere
245 casewildtop = everywhere ("casewild", casewild)
246
247 --------------------------------
248 -- Case value simplification
249 --------------------------------
250 casevalsimpl, casevalsimpltop :: Transform
251 casevalsimpl expr@(Case scrut b ty alts) = do
252   -- Try to simplify each alternative, resulting in an optional binding and a
253   -- new alternative.
254   (bindings_maybe, alts') <- (Monad.liftM unzip) $ mapM doalt alts
255   let bindings = Maybe.catMaybes bindings_maybe
256   -- Create a new let around the case, that binds of the cases values.
257   let newlet = Let (Rec bindings) (Case scrut b ty alts')
258   -- If there were no values that needed and allowed simplification, don't
259   -- change the case.
260   if null bindings then return expr else change newlet 
261   where
262     doalt :: CoreAlt -> TransformMonad (Maybe (CoreBndr, CoreExpr), CoreAlt)
263     -- Don't simplify values that are already simple
264     doalt alt@(con, bndrs, Var _) = return (Nothing, alt)
265     -- Simplify each alt by creating a new id, binding the case value to it and
266     -- replacing the case value with that id. Only do this when the case value
267     -- does not use any of the binders bound by this alternative, for that would
268     -- cause those binders to become unbound when moving the value outside of
269     -- the case statement. Also, don't create a binding for applicable
270     -- expressions, to prevent loops with inlinefun.
271     doalt (con, bndrs, expr) | (not usesvars) && (not $ is_applicable expr) = do
272       id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
273       -- We don't flag a change here, since casevalsimpl will do that above
274       -- based on Just we return here.
275       return $ (Just (id, expr), (con, bndrs, Var id))
276       -- Find if any of the binders are used by expr
277       where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
278     -- Don't simplify anything else
279     doalt alt = return (Nothing, alt)
280 -- Leave all other expressions unchanged
281 casevalsimpl expr = return expr
282 -- Perform this transform everywhere
283 casevalsimpltop = everywhere ("casevalsimpl", casevalsimpl)
284
285 --------------------------------
286 -- Case removal
287 --------------------------------
288 -- Remove case statements that have only a single alternative and only wild
289 -- binders.
290 caseremove, caseremovetop :: Transform
291 -- Replace a useless case by the value of its single alternative
292 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
293     -- Find if any of the binders are used by expr
294     where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
295 -- Leave all other expressions unchanged
296 caseremove expr = return expr
297 -- Perform this transform everywhere
298 caseremovetop = everywhere ("caseremove", caseremove)
299
300 --------------------------------
301 -- Argument extraction
302 --------------------------------
303 -- Make sure that all arguments of a representable type are simple variables.
304 appsimpl, appsimpltop :: Transform
305 -- Simplify all representable arguments. Do this by introducing a new Let
306 -- that binds the argument and passing the new binder in the application.
307 appsimpl expr@(App f arg) = do
308   -- Check runtime representability
309   repr <- isRepr arg
310   local_var <- Trans.lift $ is_local_var arg
311   if repr && not local_var
312     then do -- Extract representable arguments
313       id <- mkInternalVar "arg" (CoreUtils.exprType arg)
314       change $ Let (Rec [(id, arg)]) (App f (Var id))
315     else -- Leave non-representable arguments unchanged
316       return expr
317 -- Leave all other expressions unchanged
318 appsimpl expr = return expr
319 -- Perform this transform everywhere
320 appsimpltop = everywhere ("appsimpl", appsimpl)
321
322 --------------------------------
323 -- Function-typed argument propagation
324 --------------------------------
325 -- Remove all applications to function-typed arguments, by duplication the
326 -- function called with the function-typed parameter replaced by the free
327 -- variables of the argument passed in.
328 argprop, argproptop :: Transform
329 -- Transform any application of a named function (i.e., skip applications of
330 -- lambda's). Also skip applications that have arguments with free type
331 -- variables, since we can't inline those.
332 argprop expr@(App _ _) | is_var fexpr = do
333   -- Find the body of the function called
334   body_maybe <- Trans.lift $ getGlobalBind f
335   case body_maybe of
336     Just body -> do
337       -- Process each of the arguments in turn
338       (args', changed) <- Writer.listen $ mapM doarg args
339       -- See if any of the arguments changed
340       case Monoid.getAny changed of
341         True -> do
342           let (newargs', newparams', oldargs) = unzip3 args'
343           let newargs = concat newargs'
344           let newparams = concat newparams'
345           -- Create a new body that consists of a lambda for all new arguments and
346           -- the old body applied to some arguments.
347           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
348           -- Create a new function with the same name but a new body
349           newf <- mkFunction f newbody
350           -- Replace the original application with one of the new function to the
351           -- new arguments.
352           change $ MkCore.mkCoreApps (Var newf) newargs
353         False ->
354           -- Don't change the expression if none of the arguments changed
355           return expr
356       
357     -- If we don't have a body for the function called, leave it unchanged (it
358     -- should be a primitive function then).
359     Nothing -> return expr
360   where
361     -- Find the function called and the arguments
362     (fexpr, args) = collectArgs expr
363     Var f = fexpr
364
365     -- Process a single argument and return (args, bndrs, arg), where args are
366     -- the arguments to replace the given argument in the original
367     -- application, bndrs are the binders to include in the top-level lambda
368     -- in the new function body, and arg is the argument to apply to the old
369     -- function body.
370     doarg :: CoreExpr -> TransformMonad ([CoreExpr], [CoreBndr], CoreExpr)
371     doarg arg = do
372       repr <- isRepr arg
373       bndrs <- Trans.lift getGlobalBinders
374       let interesting var = Var.isLocalVar var && (not $ var `elem` bndrs)
375       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
376         then do
377           -- Propagate all complex arguments that are not representable, but not
378           -- arguments with free type variables (since those would require types
379           -- not known yet, which will always be known eventually).
380           -- Find interesting free variables, each of which should be passed to
381           -- the new function instead of the original function argument.
382           -- 
383           -- Interesting vars are those that are local, but not available from the
384           -- top level scope (functions from this module are defined as local, but
385           -- they're not local to this function, so we can freely move references
386           -- to them into another function).
387           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
388           -- Mark the current expression as changed
389           setChanged
390           return (map Var free_vars, free_vars, arg)
391         else do
392           -- Representable types will not be propagated, and arguments with free
393           -- type variables will be propagated later.
394           -- TODO: preserve original naming?
395           id <- mkBinderFor arg "param"
396           -- Just pass the original argument to the new function, which binds it
397           -- to a new id and just pass that new id to the old function body.
398           return ([arg], [id], mkReferenceTo id) 
399 -- Leave all other expressions unchanged
400 argprop expr = return expr
401 -- Perform this transform everywhere
402 argproptop = everywhere ("argprop", argprop)
403
404 --------------------------------
405 -- Function-typed argument extraction
406 --------------------------------
407 -- This transform takes any function-typed argument that cannot be propagated
408 -- (because the function that is applied to it is a builtin function), and
409 -- puts it in a brand new top level binder. This allows us to for example
410 -- apply map to a lambda expression This will not conflict with inlinefun,
411 -- since that only inlines local let bindings, not top level bindings.
412 funextract, funextracttop :: Transform
413 funextract expr@(App _ _) | is_var fexpr = do
414   body_maybe <- Trans.lift $ getGlobalBind f
415   case body_maybe of
416     -- We don't have a function body for f, so we can perform this transform.
417     Nothing -> do
418       -- Find the new arguments
419       args' <- mapM doarg args
420       -- And update the arguments. We use return instead of changed, so the
421       -- changed flag doesn't get set if none of the args got changed.
422       return $ MkCore.mkCoreApps fexpr args'
423     -- We have a function body for f, leave this application to funprop
424     Just _ -> return expr
425   where
426     -- Find the function called and the arguments
427     (fexpr, args) = collectArgs expr
428     Var f = fexpr
429     -- Change any arguments that have a function type, but are not simple yet
430     -- (ie, a variable or application). This means to create a new function
431     -- for map (\f -> ...) b, but not for map (foo a) b.
432     --
433     -- We could use is_applicable here instead of is_fun, but I think
434     -- arguments to functions could only have forall typing when existential
435     -- typing is enabled. Not sure, though.
436     doarg arg | not (is_simple arg) && is_fun arg = do
437       -- Create a new top level binding that binds the argument. Its body will
438       -- be extended with lambda expressions, to take any free variables used
439       -- by the argument expression.
440       let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
441       let body = MkCore.mkCoreLams free_vars arg
442       id <- mkBinderFor body "fun"
443       Trans.lift $ addGlobalBind id body
444       -- Replace the argument with a reference to the new function, applied to
445       -- all vars it uses.
446       change $ MkCore.mkCoreApps (Var id) (map Var free_vars)
447     -- Leave all other arguments untouched
448     doarg arg = return arg
449
450 -- Leave all other expressions unchanged
451 funextract expr = return expr
452 -- Perform this transform everywhere
453 funextracttop = everywhere ("funextract", funextract)
454
455 --------------------------------
456 -- End of transformations
457 --------------------------------
458
459
460
461
462 -- What transforms to run?
463 transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinenonreptop, appsimpltop]
464
465 -- Turns the given bind into VHDL
466 normalizeModule ::
467   HscTypes.HscEnv
468   -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
469   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
470   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
471   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
472   -> ([(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
473
474 normalizeModule env uniqsupply bindings generate_for statefuls = runTransformSession env uniqsupply $ do
475   -- Put all the bindings in this module in the tsBindings map
476   putA tsBindings (Map.fromList bindings)
477   -- (Recursively) normalize each of the requested bindings
478   mapM normalizeBind generate_for
479   -- Get all initial bindings and the ones we produced
480   bindings_map <- getA tsBindings
481   let bindings = Map.assocs bindings_map
482   normalized_bindings <- getA tsNormalized
483   typestate <- getA tsType
484   -- But return only the normalized bindings
485   return $ (filter ((flip VarSet.elemVarSet normalized_bindings) . fst) bindings, typestate)
486
487 normalizeBind :: CoreBndr -> TransformSession ()
488 normalizeBind bndr =
489   -- Don't normalize global variables, these should be either builtin
490   -- functions or data constructors.
491   Monad.when (Var.isLocalId bndr) $ do
492     -- Skip binders that have a polymorphic type, since it's impossible to
493     -- create polymorphic hardware.
494     if is_poly (Var bndr)
495       then
496         -- This should really only happen at the top level... TODO: Give
497         -- a different error if this happens down in the recursion.
498         error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
499       else do
500         normalized_funcs <- getA tsNormalized
501         -- See if this function was normalized already
502         if VarSet.elemVarSet bndr normalized_funcs
503           then
504             -- Yup, don't do it again
505             return ()
506           else do
507             -- Nope, note that it has been and do it.
508             modA tsNormalized (flip VarSet.extendVarSet bndr)
509             expr_maybe <- getGlobalBind bndr
510             case expr_maybe of 
511               Just expr -> do
512                 -- Introduce an empty Let at the top level, so there will always be
513                 -- a let in the expression (none of the transformations will remove
514                 -- the last let).
515                 let expr' = Let (Rec []) expr
516                 -- Normalize this expression
517                 trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
518                 expr' <- dotransforms transforms expr'
519                 trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
520                 -- And store the normalized version in the session
521                 modA tsBindings (Map.insert bndr expr')
522                 -- Find all vars used with a function type. All of these should be global
523                 -- binders (i.e., functions used), since any local binders with a function
524                 -- type should have been inlined already.
525                 bndrs <- getGlobalBinders
526                 let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
527                 let used_funcs = VarSet.varSetElems used_funcs_set
528                 -- Process each of the used functions recursively
529                 mapM normalizeBind used_funcs
530                 return ()
531               -- We don't have a value for this binder. This really shouldn't
532               -- happen for local id's...
533               Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"