Store which binders become in scope in the CoreContext.
[matthijs/master-project/cλash.git] / clash / CLasH / Normalize / NormalizeTools.hs
1 -- 
2 -- This module provides functions for program transformations.
3 --
4 module CLasH.Normalize.NormalizeTools where
5
6 -- Standard modules
7 import qualified Data.Monoid as Monoid
8 import qualified Data.Either as Either
9 import qualified Control.Monad as Monad
10 import qualified Control.Monad.Trans.Writer as Writer
11 import qualified Control.Monad.Trans.Class as Trans
12 import qualified Data.Accessor.Monad.Trans.State as MonadState
13
14 -- GHC API
15 import CoreSyn
16 import qualified Name
17 import qualified Id
18 import qualified CoreSubst
19 import qualified Type
20 import qualified CoreUtils
21 import Outputable ( showSDoc, ppr, nest )
22
23 -- Local imports
24 import CLasH.Normalize.NormalizeTypes
25 import CLasH.Translator.TranslatorTypes
26 import CLasH.VHDL.Constants (builtinIds)
27 import CLasH.Utils
28 import qualified CLasH.Utils.Core.CoreTools as CoreTools
29 import qualified CLasH.VHDL.VHDLTools as VHDLTools
30
31 -- Apply the given transformation to all expressions in the given expression,
32 -- including the expression itself.
33 everywhere :: Transform -> Transform
34 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
35
36 data NormDbgLevel = 
37     NormDbgNone         -- ^ No debugging
38   | NormDbgFinal        -- ^ Print functions before / after normalization
39   | NormDbgApplied      -- ^ Print expressions before / after applying transformations
40   | NormDbgAll          -- ^ Print expressions when a transformation does not apply
41   deriving (Eq, Ord)
42 normalize_debug = NormDbgFinal
43
44 -- Applies a transform, optionally showing some debug output.
45 apply :: (String, Transform) -> Transform
46 apply (name, trans) ctx expr =  do
47     -- Apply the transformation and find out if it changed anything
48     (expr', any_changed) <- Writer.listen $ trans ctx expr
49     let changed = Monoid.getAny any_changed
50     -- If it changed, increase the transformation counter 
51     Monad.when changed $ Trans.lift (MonadState.modify tsTransformCounter (+1))
52     -- Prepare some debug strings
53     let before = showSDoc (nest 4 $ ppr expr) ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr) ++ "\n"
54     let context = "Context: " ++ show ctx ++ "\n"
55     let after  = showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
56     traceIf (normalize_debug >= NormDbgApplied && changed) ("Changes when applying transform " ++ name ++ " to:\n" ++ before ++ context ++ "Result:\n" ++ after) $ 
57      traceIf (normalize_debug >= NormDbgAll && not changed) ("No changes when applying transform " ++ name ++ " to:\n" ++ before  ++ context) $
58      return expr'
59
60 -- Apply the first transformation, followed by the second transformation, and
61 -- keep applying both for as long as expression still changes.
62 applyboth :: Transform -> Transform -> Transform
63 applyboth first second context expr = do
64   -- Apply the first
65   expr' <- first context expr
66   -- Apply the second
67   (expr'', changed) <- Writer.listen $ second context expr'
68   if Monoid.getAny $ changed
69     then
70       applyboth first second context expr'' 
71     else 
72       return expr''
73
74 -- Apply the given transformation to all direct subexpressions (only), not the
75 -- expression itself.
76 subeverywhere :: Transform -> Transform
77 subeverywhere trans c (App a b) = do
78   a' <- trans (AppFirst:c) a
79   b' <- trans (AppSecond:c) b
80   return $ App a' b'
81
82 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
83   -- In the binding of a non-recursive let binding, no extra binders are
84   -- in scope.
85   bexpr' <- trans (LetBinding []:c) bexpr
86   -- In the body of a non-recursive let binding, the bound binder is in
87   -- scope.
88   expr' <- trans ((LetBody [b]):c) expr
89   return $ Let (NonRec b bexpr') expr'
90
91 subeverywhere trans c (Let (Rec binds) expr) = do
92   -- In the body of a recursive let, all binders are in scope
93   expr' <- trans ((LetBody bndrs):c) expr
94   binds' <- mapM transbind binds
95   return $ Let (Rec binds') expr'
96   where
97     bndrs = map fst binds
98     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
99     transbind (b, e) = do
100       -- In the bindings of a recursive let, all binders are in scope
101       e' <- trans ((LetBinding bndrs):c) e
102       return (b, e')
103
104 subeverywhere trans c (Lam x expr) = do
105   -- In the body of a lambda, the bound binder is in scope.
106   expr' <- trans ((LambdaBody x):c) expr
107   return $ Lam x expr'
108
109 subeverywhere trans c (Case scrut b t alts) = do
110   scrut' <- trans (Other:c) scrut
111   alts' <- mapM transalt alts
112   return $ Case scrut' b t alts'
113   where
114     transalt :: CoreAlt -> TransformMonad CoreAlt
115     transalt (con, binders, expr) = do
116       expr' <- trans (Other:c) expr
117       return (con, binders, expr')
118
119 subeverywhere trans c (Var x) = return $ Var x
120 subeverywhere trans c (Lit x) = return $ Lit x
121 subeverywhere trans c (Type x) = return $ Type x
122
123 subeverywhere trans c (Cast expr ty) = do
124   expr' <- trans (Other:c) expr
125   return $ Cast expr' ty
126
127 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
128
129 -- Runs each of the transforms repeatedly inside the State monad.
130 dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
131 dotransforms transs expr = do
132   (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere (apply trans) [] e) expr transs
133   if Monoid.getAny changed then dotransforms transs expr' else return expr'
134
135 -- Inline all let bindings that satisfy the given condition
136 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
137 inlinebind condition context expr@(Let (Rec binds) res) = do
138     -- Find all bindings that adhere to the condition
139     res_eithers <- mapM docond binds
140     case Either.partitionEithers res_eithers of
141       -- No replaces? No change
142       ([], _) -> return expr
143       (replace, others) -> do
144         -- Substitute the to be replaced binders with their expression
145         newexpr <- do_substitute replace (Let (Rec others) res)
146         change newexpr
147   where 
148     -- Apply the condition to a let binding and return an Either
149     -- depending on whether it needs to be inlined or not.
150     docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
151     docond b = do
152       res <- condition b
153       return $ case res of True -> Left b; False -> Right b
154
155     -- Apply the given list of substitutions to the the given expression
156     do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
157     do_substitute [] expr = return expr
158     do_substitute ((bndr, val):reps) expr = do
159       -- Perform this substitution in the expression
160       expr' <- substitute_clone bndr val context expr
161       -- And in the substitution values we will be using next
162       reps' <- mapM (subs_bind bndr val) reps
163       -- And then perform the remaining substitutions
164       do_substitute reps' expr'
165
166     -- All binders bound in the transformed recursive let
167     bndrs = map fst binds
168    
169     -- Replace the given binder with the given expression in the
170     -- expression oft the given let binding
171     subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
172     subs_bind bndr expr (b, v) = do
173       v' <- substitute_clone  bndr expr ((LetBinding bndrs):context) v
174       return (b, v')
175
176
177 -- Leave all other expressions unchanged
178 inlinebind _ context expr = return expr
179
180 -- Sets the changed flag in the TransformMonad, to signify that some
181 -- transform has changed the result
182 setChanged :: TransformMonad ()
183 setChanged = Writer.tell (Monoid.Any True)
184
185 -- Sets the changed flag and returns the given value.
186 change :: a -> TransformMonad a
187 change val = do
188   setChanged
189   return val
190
191 -- Returns the given value and sets the changed flag if the bool given is
192 -- True. Note that this will not unset the changed flag if the bool is False.
193 changeif :: Bool -> a -> TransformMonad a
194 changeif True val = change val
195 changeif False val = return val
196
197 -- | Creates a transformation that substitutes the given binder with the given
198 -- expression (This can be a type variable, replace by a Type expression).
199 -- Does not set the changed flag.
200 substitute :: CoreBndr -> CoreExpr -> Transform
201 -- Use CoreSubst to subst a type var in an expression
202 substitute find repl context expr = do
203   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
204   return $ CoreSubst.substExpr subst expr 
205
206 -- | Creates a transformation that substitutes the given binder with the given
207 -- expression. This does only work for value expressions! All binders in the
208 -- expression are cloned before the replacement, to guarantee uniqueness.
209 substitute_clone :: CoreBndr -> CoreExpr -> Transform
210 -- If we see the var to find, replace it by a uniqued version of repl
211 substitute_clone find repl context (Var var) | find == var = do
212   repl' <- Trans.lift $ CoreTools.genUniques repl
213   change repl'
214
215 -- For all other expressions, just look in subexpressions
216 substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
217
218 -- Is the given expression representable at runtime, based on the type?
219 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
220 isRepr tything = Trans.lift (isRepr' tything)
221
222 isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
223 isRepr' tything = case CoreTools.getType tything of
224   Nothing -> return False
225   Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty 
226
227 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
228 is_local_var (CoreSyn.Var v) = do
229   bndrs <- getGlobalBinders
230   -- A datacon id is not a global binder, but not a local variable
231   -- either.
232   let is_dc = Id.isDataConWorkId v
233   return $ not is_dc && v `notElem` bndrs
234 is_local_var _ = return False
235
236 -- Is the given binder defined by the user?
237 isUserDefined :: CoreSyn.CoreBndr -> Bool
238 -- System names are certain to not be user defined
239 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
240 -- Builtin functions are usually not user-defined either (and would
241 -- break currently if they are...)
242 isUserDefined bndr = str `notElem` builtinIds
243   where
244     str = Name.getOccString bndr
245
246 -- | Is the given binder normalizable? This means that its type signature can be
247 -- represented in hardware, which should (?) guarantee that it can be made
248 -- into hardware. This checks whether all the arguments and (optionally)
249 -- the return value are
250 -- representable.
251 isNormalizeable :: 
252   Bool -- ^ Allow the result to be unrepresentable?
253   -> CoreBndr  -- ^ The binder to check
254   -> TranslatorSession Bool  -- ^ Is it normalizeable?
255 isNormalizeable result_nonrep bndr = do
256   let ty = Id.idType bndr
257   let (arg_tys, res_ty) = Type.splitFunTys ty
258   let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
259   andM $ mapM isRepr' check_tys