Quick hack implementation of FSVec literals, needs to be fixed
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
1 {-# LANGUAGE PackageImports #-}
2 --
3 -- Functions to bring a Core expression in normal form. This module provides a
4 -- top level function "normalize", and defines the actual transformation passes that
5 -- are performed.
6 --
7 module CLasH.Normalize (normalizeModule) where
8
9 -- Standard modules
10 import Debug.Trace
11 import qualified Maybe
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import qualified Control.Monad as Monad
14 import qualified Control.Monad.Trans.Writer as Writer
15 import qualified Data.Map as Map
16 import qualified Data.Monoid as Monoid
17 import Data.Accessor
18
19 -- GHC API
20 import CoreSyn
21 import qualified UniqSupply
22 import qualified CoreUtils
23 import qualified Type
24 import qualified TcType
25 import qualified Id
26 import qualified Var
27 import qualified VarSet
28 import qualified NameSet
29 import qualified CoreFVs
30 import qualified CoreUtils
31 import qualified MkCore
32 import qualified HscTypes
33 import Outputable ( showSDoc, ppr, nest )
34
35 -- Local imports
36 import CLasH.Normalize.NormalizeTypes
37 import CLasH.Normalize.NormalizeTools
38 import CLasH.VHDL.VHDLTypes
39 import CLasH.Utils.Core.CoreTools
40 import CLasH.Utils.Pretty
41
42 --------------------------------
43 -- Start of transformations
44 --------------------------------
45
46 --------------------------------
47 -- η abstraction
48 --------------------------------
49 eta, etatop :: Transform
50 eta expr | is_fun expr && not (is_lam expr) = do
51   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
52   id <- mkInternalVar "param" arg_ty
53   change (Lam id (App expr (Var id)))
54 -- Leave all other expressions unchanged
55 eta e = return e
56 etatop = notappargs ("eta", eta)
57
58 --------------------------------
59 -- β-reduction
60 --------------------------------
61 beta, betatop :: Transform
62 -- Substitute arg for x in expr
63 beta (App (Lam x expr) arg) = change $ substitute [(x, arg)] expr
64 -- Propagate the application into the let
65 beta (App (Let binds expr) arg) = change $ Let binds (App expr arg)
66 -- Propagate the application into each of the alternatives
67 beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts'
68   where 
69     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts
70     ty' = CoreUtils.applyTypeToArg ty arg
71 -- Leave all other expressions unchanged
72 beta expr = return expr
73 -- Perform this transform everywhere
74 betatop = everywhere ("beta", beta)
75
76 --------------------------------
77 -- Cast propagation
78 --------------------------------
79 -- Try to move casts as much downward as possible.
80 castprop, castproptop :: Transform
81 castprop (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty)
82 castprop expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts')
83   where
84     alts' = map (\(con, bndrs, expr) -> (con, bndrs, (Cast expr ty))) alts
85 -- Leave all other expressions unchanged
86 castprop expr = return expr
87 -- Perform this transform everywhere
88 castproptop = everywhere ("castprop", castprop)
89
90 --------------------------------
91 -- let recursification
92 --------------------------------
93 letrec, letrectop :: Transform
94 letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
95 -- Leave all other expressions unchanged
96 letrec expr = return expr
97 -- Perform this transform everywhere
98 letrectop = everywhere ("letrec", letrec)
99
100 --------------------------------
101 -- let simplification
102 --------------------------------
103 letsimpl, letsimpltop :: Transform
104 -- Put the "in ..." value of a let in its own binding, but not when the
105 -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
106 letsimpl expr@(Let (Rec binds) res) = do
107   repr <- isRepr res
108   local_var <- Trans.lift $ is_local_var res
109   if not local_var && repr
110     then do
111       -- If the result is not a local var already (to prevent loops with
112       -- ourselves), extract it.
113       id <- mkInternalVar "foo" (CoreUtils.exprType res)
114       let bind = (id, res)
115       change $ Let (Rec (bind:binds)) (Var id)
116     else
117       -- If the result is already a local var, don't extract it.
118       return expr
119
120 -- Leave all other expressions unchanged
121 letsimpl expr = return expr
122 -- Perform this transform everywhere
123 letsimpltop = everywhere ("letsimpl", letsimpl)
124
125 --------------------------------
126 -- let flattening
127 --------------------------------
128 letflat, letflattop :: Transform
129 letflat (Let (Rec binds) expr) = do
130   -- Turn each binding into a list of bindings (possibly containing just one
131   -- element, of course)
132   bindss <- Monad.mapM flatbind binds
133   -- Concat all the bindings
134   let binds' = concat bindss
135   -- Return the new let. We don't use change here, since possibly nothing has
136   -- changed. If anything has changed, flatbind has already flagged that
137   -- change.
138   return $ Let (Rec binds') expr
139   where
140     -- Turns a binding of a let into a multiple bindings, or any other binding
141     -- into a list with just that binding
142     flatbind :: (CoreBndr, CoreExpr) -> TransformMonad [(CoreBndr, CoreExpr)]
143     flatbind (b, Let (Rec binds) expr) = change ((b, expr):binds)
144     flatbind (b, expr) = return [(b, expr)]
145 -- Leave all other expressions unchanged
146 letflat expr = return expr
147 -- Perform this transform everywhere
148 letflattop = everywhere ("letflat", letflat)
149
150 --------------------------------
151 -- Simple let binding removal
152 --------------------------------
153 -- Remove a = b bindings from let expressions everywhere
154 letremovetop :: Transform
155 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
156
157 --------------------------------
158 -- Function inlining
159 --------------------------------
160 -- Remove a = B bindings, with B :: a -> b, or B :: forall x . T, from let
161 -- expressions everywhere. This means that any value that still needs to be
162 -- applied to something else (polymorphic values need to be applied to a
163 -- Type) will be inlined, and will eventually be applied to all their
164 -- arguments.
165 --
166 -- This is a tricky function, which is prone to create loops in the
167 -- transformations. To fix this, we make sure that no transformation will
168 -- create a new let binding with a function type. These other transformations
169 -- will just not work on those function-typed values at first, but the other
170 -- transformations (in particular β-reduction) should make sure that the type
171 -- of those values eventually becomes primitive.
172 inlinenonreptop :: Transform
173 inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . isRepr . snd))
174
175 --------------------------------
176 -- Scrutinee simplification
177 --------------------------------
178 scrutsimpl,scrutsimpltop :: Transform
179 -- Don't touch scrutinees that are already simple
180 scrutsimpl expr@(Case (Var _) _ _ _) = return expr
181 -- Replace all other cases with a let that binds the scrutinee and a new
182 -- simple scrutinee, but not when the scrutinee is applicable (to prevent
183 -- loops with inlinefun, though I don't think a scrutinee can be
184 -- applicable...)
185 scrutsimpl (Case scrut b ty alts) | not $ is_applicable scrut = do
186   id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
187   change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
188 -- Leave all other expressions unchanged
189 scrutsimpl expr = return expr
190 -- Perform this transform everywhere
191 scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl)
192
193 --------------------------------
194 -- Case binder wildening
195 --------------------------------
196 casewild, casewildtop :: Transform
197 casewild expr@(Case scrut b ty alts) = do
198   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
199   let bindings = concat bindingss
200   -- Replace the case with a let with bindings and a case
201   let newlet = (Let (Rec bindings) (Case scrut b ty alts'))
202   -- If there are no non-wild binders, or this case is already a simple
203   -- selector (i.e., a single alt with exactly one binding), already a simple
204   -- selector altan no bindings (i.e., no wild binders in the original case),
205   -- don't change anything, otherwise, replace the case.
206   if null bindings || length alts == 1 && length bindings == 1 then return expr else change newlet 
207   where
208   -- Generate a single wild binder, since they are all the same
209   wild = MkCore.mkWildBinder
210   -- Wilden the binders of one alt, producing a list of bindings as a
211   -- sideeffect.
212   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
213   doalt (con, bndrs, expr) = do
214     bindings_maybe <- Monad.zipWithM mkextracts bndrs [0..]
215     let bindings = Maybe.catMaybes bindings_maybe
216     -- We replace the binders with wild binders only. We can leave expr
217     -- unchanged, since the new bindings bind the same vars as the original
218     -- did.
219     let newalt = (con, wildbndrs, expr)
220     return (bindings, newalt)
221     where
222       -- Make all binders wild
223       wildbndrs = map (\bndr -> MkCore.mkWildBinder (Id.idType bndr)) bndrs
224       -- A set of all the binders that are used by the expression
225       free_vars = CoreFVs.exprSomeFreeVars (`elem` bndrs) expr
226       -- Creates a case statement to retrieve the ith element from the scrutinee
227       -- and binds that to b.
228       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
229       mkextracts b i =
230         if not (VarSet.elemVarSet b free_vars) || Type.isFunTy (Id.idType b) 
231           -- Don't create extra bindings for binders that are already wild
232           -- (e.g. not in the free variables of expr, so unused), or for
233           -- binders that bind function types (to prevent loops with
234           -- inlinefun).
235           then return Nothing
236           else do
237             -- Create on new binder that will actually capture a value in this
238             -- case statement, and return it
239             let bty = (Id.idType b)
240             id <- mkInternalVar "sel" bty
241             let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
242             return $ Just (b, Case scrut b bty [(con, binders, Var id)])
243 -- Leave all other expressions unchanged
244 casewild expr = return expr
245 -- Perform this transform everywhere
246 casewildtop = everywhere ("casewild", casewild)
247
248 --------------------------------
249 -- Case value simplification
250 --------------------------------
251 casevalsimpl, casevalsimpltop :: Transform
252 casevalsimpl expr@(Case scrut b ty alts) = do
253   -- Try to simplify each alternative, resulting in an optional binding and a
254   -- new alternative.
255   (bindings_maybe, alts') <- (Monad.liftM unzip) $ mapM doalt alts
256   let bindings = Maybe.catMaybes bindings_maybe
257   -- Create a new let around the case, that binds of the cases values.
258   let newlet = Let (Rec bindings) (Case scrut b ty alts')
259   -- If there were no values that needed and allowed simplification, don't
260   -- change the case.
261   if null bindings then return expr else change newlet 
262   where
263     doalt :: CoreAlt -> TransformMonad (Maybe (CoreBndr, CoreExpr), CoreAlt)
264     -- Don't simplify values that are already simple
265     doalt alt@(con, bndrs, Var _) = return (Nothing, alt)
266     -- Simplify each alt by creating a new id, binding the case value to it and
267     -- replacing the case value with that id. Only do this when the case value
268     -- does not use any of the binders bound by this alternative, for that would
269     -- cause those binders to become unbound when moving the value outside of
270     -- the case statement. Also, don't create a binding for applicable
271     -- expressions, to prevent loops with inlinefun.
272     doalt (con, bndrs, expr) | (not usesvars) && (not $ is_applicable expr) = do
273       id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
274       -- We don't flag a change here, since casevalsimpl will do that above
275       -- based on Just we return here.
276       return $ (Just (id, expr), (con, bndrs, Var id))
277       -- Find if any of the binders are used by expr
278       where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
279     -- Don't simplify anything else
280     doalt alt = return (Nothing, alt)
281 -- Leave all other expressions unchanged
282 casevalsimpl expr = return expr
283 -- Perform this transform everywhere
284 casevalsimpltop = everywhere ("casevalsimpl", casevalsimpl)
285
286 --------------------------------
287 -- Case removal
288 --------------------------------
289 -- Remove case statements that have only a single alternative and only wild
290 -- binders.
291 caseremove, caseremovetop :: Transform
292 -- Replace a useless case by the value of its single alternative
293 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
294     -- Find if any of the binders are used by expr
295     where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
296 -- Leave all other expressions unchanged
297 caseremove expr = return expr
298 -- Perform this transform everywhere
299 caseremovetop = everywhere ("caseremove", caseremove)
300
301 --------------------------------
302 -- Argument extraction
303 --------------------------------
304 -- Make sure that all arguments of a representable type are simple variables.
305 appsimpl, appsimpltop :: Transform
306 -- Simplify all representable arguments. Do this by introducing a new Let
307 -- that binds the argument and passing the new binder in the application.
308 appsimpl expr@(App f arg) = do
309   -- Check runtime representability
310   repr <- isRepr arg
311   local_var <- Trans.lift $ is_local_var arg
312   if repr && not local_var
313     then do -- Extract representable arguments
314       id <- mkInternalVar "arg" (CoreUtils.exprType arg)
315       change $ Let (Rec [(id, arg)]) (App f (Var id))
316     else -- Leave non-representable arguments unchanged
317       return expr
318 -- Leave all other expressions unchanged
319 appsimpl expr = return expr
320 -- Perform this transform everywhere
321 appsimpltop = everywhere ("appsimpl", appsimpl)
322
323 --------------------------------
324 -- Function-typed argument propagation
325 --------------------------------
326 -- Remove all applications to function-typed arguments, by duplication the
327 -- function called with the function-typed parameter replaced by the free
328 -- variables of the argument passed in.
329 argprop, argproptop :: Transform
330 -- Transform any application of a named function (i.e., skip applications of
331 -- lambda's). Also skip applications that have arguments with free type
332 -- variables, since we can't inline those.
333 argprop expr@(App _ _) | is_var fexpr = do
334   -- Find the body of the function called
335   body_maybe <- Trans.lift $ getGlobalBind f
336   case body_maybe of
337     Just body -> do
338       -- Process each of the arguments in turn
339       (args', changed) <- Writer.listen $ mapM doarg args
340       -- See if any of the arguments changed
341       case Monoid.getAny changed of
342         True -> do
343           let (newargs', newparams', oldargs) = unzip3 args'
344           let newargs = concat newargs'
345           let newparams = concat newparams'
346           -- Create a new body that consists of a lambda for all new arguments and
347           -- the old body applied to some arguments.
348           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
349           -- Create a new function with the same name but a new body
350           newf <- mkFunction f newbody
351           -- Replace the original application with one of the new function to the
352           -- new arguments.
353           change $ MkCore.mkCoreApps (Var newf) newargs
354         False ->
355           -- Don't change the expression if none of the arguments changed
356           return expr
357       
358     -- If we don't have a body for the function called, leave it unchanged (it
359     -- should be a primitive function then).
360     Nothing -> return expr
361   where
362     -- Find the function called and the arguments
363     (fexpr, args) = collectArgs expr
364     Var f = fexpr
365
366     -- Process a single argument and return (args, bndrs, arg), where args are
367     -- the arguments to replace the given argument in the original
368     -- application, bndrs are the binders to include in the top-level lambda
369     -- in the new function body, and arg is the argument to apply to the old
370     -- function body.
371     doarg :: CoreExpr -> TransformMonad ([CoreExpr], [CoreBndr], CoreExpr)
372     doarg arg = do
373       repr <- isRepr arg
374       bndrs <- Trans.lift getGlobalBinders
375       let interesting var = Var.isLocalVar var && (not $ var `elem` bndrs)
376       if not repr && not (is_var arg && interesting (exprToVar arg)) && not (has_free_tyvars arg) 
377         then do
378           -- Propagate all complex arguments that are not representable, but not
379           -- arguments with free type variables (since those would require types
380           -- not known yet, which will always be known eventually).
381           -- Find interesting free variables, each of which should be passed to
382           -- the new function instead of the original function argument.
383           -- 
384           -- Interesting vars are those that are local, but not available from the
385           -- top level scope (functions from this module are defined as local, but
386           -- they're not local to this function, so we can freely move references
387           -- to them into another function).
388           let free_vars = VarSet.varSetElems $ CoreFVs.exprSomeFreeVars interesting arg
389           -- Mark the current expression as changed
390           setChanged
391           return (map Var free_vars, free_vars, arg)
392         else do
393           -- Representable types will not be propagated, and arguments with free
394           -- type variables will be propagated later.
395           -- TODO: preserve original naming?
396           id <- mkBinderFor arg "param"
397           -- Just pass the original argument to the new function, which binds it
398           -- to a new id and just pass that new id to the old function body.
399           return ([arg], [id], mkReferenceTo id) 
400 -- Leave all other expressions unchanged
401 argprop expr = return expr
402 -- Perform this transform everywhere
403 argproptop = everywhere ("argprop", argprop)
404
405 --------------------------------
406 -- Function-typed argument extraction
407 --------------------------------
408 -- This transform takes any function-typed argument that cannot be propagated
409 -- (because the function that is applied to it is a builtin function), and
410 -- puts it in a brand new top level binder. This allows us to for example
411 -- apply map to a lambda expression This will not conflict with inlinefun,
412 -- since that only inlines local let bindings, not top level bindings.
413 funextract, funextracttop :: Transform
414 funextract expr@(App _ _) | is_var fexpr = do
415   body_maybe <- Trans.lift $ getGlobalBind f
416   case body_maybe of
417     -- We don't have a function body for f, so we can perform this transform.
418     Nothing -> do
419       -- Find the new arguments
420       args' <- mapM doarg args
421       -- And update the arguments. We use return instead of changed, so the
422       -- changed flag doesn't get set if none of the args got changed.
423       return $ MkCore.mkCoreApps fexpr args'
424     -- We have a function body for f, leave this application to funprop
425     Just _ -> return expr
426   where
427     -- Find the function called and the arguments
428     (fexpr, args) = collectArgs expr
429     Var f = fexpr
430     -- Change any arguments that have a function type, but are not simple yet
431     -- (ie, a variable or application). This means to create a new function
432     -- for map (\f -> ...) b, but not for map (foo a) b.
433     --
434     -- We could use is_applicable here instead of is_fun, but I think
435     -- arguments to functions could only have forall typing when existential
436     -- typing is enabled. Not sure, though.
437     doarg arg | not (is_simple arg) && is_fun arg = do
438       -- Create a new top level binding that binds the argument. Its body will
439       -- be extended with lambda expressions, to take any free variables used
440       -- by the argument expression.
441       let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
442       let body = MkCore.mkCoreLams free_vars arg
443       id <- mkBinderFor body "fun"
444       Trans.lift $ addGlobalBind id body
445       -- Replace the argument with a reference to the new function, applied to
446       -- all vars it uses.
447       change $ MkCore.mkCoreApps (Var id) (map Var free_vars)
448     -- Leave all other arguments untouched
449     doarg arg = return arg
450
451 -- Leave all other expressions unchanged
452 funextract expr = return expr
453 -- Perform this transform everywhere
454 funextracttop = everywhere ("funextract", funextract)
455
456 --------------------------------
457 -- End of transformations
458 --------------------------------
459
460
461
462
463 -- What transforms to run?
464 transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinenonreptop, appsimpltop]
465
466 -- Turns the given bind into VHDL
467 normalizeModule ::
468   HscTypes.HscEnv
469   -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
470   -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
471   -> [CoreExpr]
472   -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
473   -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
474   -> ([(CoreBndr, CoreExpr)], [(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
475
476 normalizeModule env uniqsupply bindings testexprs generate_for statefuls = runTransformSession env uniqsupply $ do
477   testbinds <- mapM (\x -> do { v <- mkBinderFor' x "test" ; return (v,x) } ) testexprs
478   let testbinders = (map fst testbinds)
479   -- Put all the bindings in this module in the tsBindings map
480   putA tsBindings (Map.fromList (bindings ++ testbinds))
481   -- (Recursively) normalize each of the requested bindings
482   mapM normalizeBind (generate_for ++ testbinders)
483   -- Get all initial bindings and the ones we produced
484   bindings_map <- getA tsBindings
485   let bindings = Map.assocs bindings_map
486   normalized_binders' <- getA tsNormalized
487   let normalized_binders = VarSet.delVarSetList normalized_binders' testbinders
488   let ret_testbinds = zip testbinders (Maybe.catMaybes $ map (\x -> lookup x bindings) testbinders)
489   let ret_binds = filter ((`VarSet.elemVarSet` normalized_binders) . fst) bindings
490   typestate <- getA tsType
491   -- But return only the normalized bindings
492   return $ (ret_binds, ret_testbinds, typestate)
493
494 normalizeBind :: CoreBndr -> TransformSession ()
495 normalizeBind bndr =
496   -- Don't normalize global variables, these should be either builtin
497   -- functions or data constructors.
498   Monad.when (Var.isLocalId bndr) $ do
499     -- Skip binders that have a polymorphic type, since it's impossible to
500     -- create polymorphic hardware.
501     if is_poly (Var bndr)
502       then
503         -- This should really only happen at the top level... TODO: Give
504         -- a different error if this happens down in the recursion.
505         error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
506       else do
507         normalized_funcs <- getA tsNormalized
508         -- See if this function was normalized already
509         if VarSet.elemVarSet bndr normalized_funcs
510           then
511             -- Yup, don't do it again
512             return ()
513           else do
514             -- Nope, note that it has been and do it.
515             modA tsNormalized (flip VarSet.extendVarSet bndr)
516             expr_maybe <- getGlobalBind bndr
517             case expr_maybe of 
518               Just expr -> do
519                 -- Introduce an empty Let at the top level, so there will always be
520                 -- a let in the expression (none of the transformations will remove
521                 -- the last let).
522                 let expr' = Let (Rec []) expr
523                 -- Normalize this expression
524                 trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
525                 expr' <- dotransforms transforms expr'
526                 trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
527                 -- And store the normalized version in the session
528                 modA tsBindings (Map.insert bndr expr')
529                 -- Find all vars used with a function type. All of these should be global
530                 -- binders (i.e., functions used), since any local binders with a function
531                 -- type should have been inlined already.
532                 bndrs <- getGlobalBinders
533                 let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
534                 let used_funcs = VarSet.varSetElems used_funcs_set
535                 -- Process each of the used functions recursively
536                 mapM normalizeBind used_funcs
537                 return ()
538               -- We don't have a value for this binder. This really shouldn't
539               -- happen for local id's...
540               Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"