2 -- This module provides functions for program transformations.
4 module CLasH.Normalize.NormalizeTools where
7 import qualified Data.Monoid as Monoid
8 import qualified Data.Either as Either
9 import qualified Control.Monad as Monad
10 import qualified Control.Monad.Trans.Writer as Writer
11 import qualified Control.Monad.Trans.Class as Trans
12 import qualified Data.Accessor.Monad.Trans.State as MonadState
18 import qualified CoreSubst
20 import qualified CoreUtils
21 import Outputable ( showSDoc, ppr, nest )
24 import CLasH.Normalize.NormalizeTypes
25 import CLasH.Translator.TranslatorTypes
26 import CLasH.VHDL.Constants (builtinIds)
28 import qualified CLasH.Utils.Core.CoreTools as CoreTools
29 import qualified CLasH.VHDL.VHDLTools as VHDLTools
-- | Apply the given transformation to every expression contained in the
-- given expression, including the expression itself. The direct
-- subexpressions are processed (recursively) first, after which the
-- transformation is tried on the expression as a whole.
everywhere :: Transform -> Transform
everywhere trans = applyboth descend trans
  where
    -- Recurse into all direct subexpressions.
    descend = subeverywhere (everywhere trans)
-- | The amount of debug output produced during normalization. The
-- constructors are listed in increasing order of verbosity, so that levels
-- can be compared using the derived 'Ord' instance.
data NormDbgLevel
  = NormDbgNone    -- ^ No debugging
  | NormDbgFinal   -- ^ Print functions before / after normalization
  | NormDbgApplied -- ^ Print expressions before / after applying transformations
  | NormDbgAll     -- ^ Print expressions when a transformation does not apply
  deriving (Eq, Ord)

-- | The debug level that is in effect.
normalize_debug :: NormDbgLevel
normalize_debug = NormDbgFinal
-- | Applies a named transform, optionally showing some debug output.
--
-- Whenever the transform reports a change, the transformation counter in
-- the session state is incremented. Depending on 'normalize_debug', the
-- expression before (and after) the transformation is traced. The result
-- of the transform is returned unmodified.
apply :: (String, Transform) -> Transform
apply (name, trans) ctx expr = do
  -- Apply the transformation and find out if it changed anything
  (expr', any_changed) <- Writer.listen $ trans ctx expr
  let changed = Monoid.getAny any_changed
  -- If it changed, increase the transformation counter
  Monad.when changed $ Trans.lift (MonadState.modify tsTransformCounter (+1))
  -- Prepare some debug strings
  let before = describe expr
  let context = "Context: " ++ show ctx ++ "\n"
  let after = describe expr'
  traceIf (normalize_debug >= NormDbgApplied && changed) ("Changes when applying transform " ++ name ++ " to:\n" ++ before ++ context ++ "Result:\n" ++ after) $
    traceIf (normalize_debug >= NormDbgAll && not changed) ("No changes when applying transform " ++ name ++ " to:\n" ++ before ++ context) $
    return expr'
  where
    -- Pretty print an expression together with its type.
    describe e = showSDoc (nest 4 $ ppr e) ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType e) ++ "\n"
-- | Apply the first transformation, followed by the second transformation,
-- and keep applying both for as long as the second transformation reports
-- that it changed the expression.
applyboth :: Transform -> Transform -> Transform
applyboth first second context expr = do
  -- Apply the first
  expr' <- first context expr
  -- Apply the second, recording whether it changed anything
  (expr'', changed) <- Writer.listen $ second context expr'
  if Monoid.getAny changed
    then
      -- Something changed, so go around again
      applyboth first second context expr''
    else
      return expr''
-- | Apply the given transformation to all direct subexpressions (only), not
-- the expression itself. The context passed to the transformation is
-- extended with an entry describing the position of the subexpression.
-- Calls 'error' for expression forms that are not handled below.
subeverywhere :: Transform -> Transform
subeverywhere trans c (App a b) = do
  a' <- trans (AppFirst:c) a
  b' <- trans (AppSecond:c) b
  return $ App a' b'

subeverywhere trans c (Let (NonRec b bexpr) expr) = do
  -- In the binding of a non-recursive let binding, no extra binders are
  -- in scope.
  bexpr' <- trans (LetBinding []:c) bexpr
  -- In the body of a non-recursive let binding, the bound binder is in
  -- scope.
  expr' <- trans ((LetBody [b]):c) expr
  return $ Let (NonRec b bexpr') expr'

subeverywhere trans c (Let (Rec binds) expr) = do
  -- In the body of a recursive let, all binders are in scope
  expr' <- trans ((LetBody bndrs):c) expr
  binds' <- mapM transbind binds
  return $ Let (Rec binds') expr'
  where
    -- All binders bound by this recursive let
    bndrs = map fst binds
    transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
    transbind (b, e) = do
      -- In the bindings of a recursive let, all binders are in scope
      e' <- trans ((LetBinding bndrs):c) e
      return (b, e')

subeverywhere trans c (Lam x expr) = do
  -- In the body of a lambda, the bound binder is in scope.
  expr' <- trans ((LambdaBody x):c) expr
  return $ Lam x expr'

subeverywhere trans c (Case scrut b t alts) = do
  scrut' <- trans (Other:c) scrut
  alts' <- mapM transalt alts
  return $ Case scrut' b t alts'
  where
    -- Transform the expression of a single alternative
    transalt :: CoreAlt -> TransformMonad CoreAlt
    transalt (con, binders, expr) = do
      expr' <- trans (Other:c) expr
      return (con, binders, expr')

-- These expressions have no subexpressions to transform
subeverywhere trans c (Var x) = return $ Var x
subeverywhere trans c (Lit x) = return $ Lit x
subeverywhere trans c (Type x) = return $ Type x

subeverywhere trans c (Cast expr ty) = do
  expr' <- trans (Other:c) expr
  return $ Cast expr' ty

subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
-- | Applies every transform in the list, in order, everywhere in the given
-- expression, and repeats the whole pass until a pass reports no change.
dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
dotransforms transs expr = do
  (expr', changed) <- Writer.runWriterT (Monad.foldM step expr transs)
  if Monoid.getAny changed
    then dotransforms transs expr'
    else return expr'
  where
    -- Run a single named transform over the complete expression, starting
    -- from an empty context.
    step e named = everywhere (apply named) [] e
-- | Inline all let bindings that satisfy the given condition.
--
-- Only recursive lets are considered; every other expression is returned
-- unchanged. When at least one binding satisfies the condition, those
-- bindings are removed from the let, their binders are replaced by (cloned
-- copies of) their expressions, and the changed flag is set.
inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
inlinebind condition context expr@(Let (Rec binds) res) = do
    -- Find all bindings that adhere to the condition
    res_eithers <- mapM docond binds
    case Either.partitionEithers res_eithers of
      -- No replaces? No change
      ([], _) -> return expr
      (replace, others) -> do
        -- Substitute the to be replaced binders with their expression
        newexpr <- do_substitute replace (Let (Rec others) res)
        change newexpr
  where
    -- Apply the condition to a let binding and return an Either
    -- depending on whether it needs to be inlined or not.
    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
    docond b = do
      inline <- condition b
      return $ if inline then Left b else Right b

    -- Apply the given list of substitutions to the given expression
    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
    do_substitute [] e = return e
    do_substitute ((bndr, val):reps) e = do
      -- Perform this substitution in the expression
      e' <- substitute_clone bndr val context e
      -- And in the substitution values we will be using next
      reps' <- mapM (subs_bind bndr val) reps
      -- And then perform the remaining substitutions
      do_substitute reps' e'

    -- All binders bound in the transformed recursive let
    bndrs = map fst binds

    -- Replace the given binder with the given expression in the
    -- expression of the given let binding
    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
    subs_bind bndr val (b, v) = do
      v' <- substitute_clone bndr val ((LetBinding bndrs):context) v
      return (b, v')

-- Leave all other expressions unchanged
inlinebind _ context expr = return expr
-- | Raises the changed flag in the 'TransformMonad', recording that some
-- transform has modified the expression it was applied to.
setChanged :: TransformMonad ()
setChanged = Writer.tell $ Monoid.Any True
-- | Sets the changed flag and returns the given value.
change :: a -> TransformMonad a
change val = do
  setChanged
  return val
-- | Returns the given value, setting the changed flag only when the given
-- bool is True. A False bool leaves the flag as it was; it never unsets a
-- flag that has already been set.
changeif :: Bool -> a -> TransformMonad a
changeif flag val
  | flag = change val
  | otherwise = return val
-- | Creates a transformation that substitutes the given binder with the
-- given expression (the binder can be a type variable, in which case the
-- replacement should be a Type expression). The substitution is performed
-- by CoreSubst, and the changed flag is not set.
substitute :: CoreBndr -> CoreExpr -> Transform
substitute find repl context expr = return $ CoreSubst.substExpr subst expr
  where
    -- A substitution that maps only the binder to find onto its replacement
    subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
-- | Creates a transformation that substitutes the given binder with the given
-- expression. This does only work for value expressions! All binders in the
-- expression are cloned before the replacement, to guarantee uniqueness.
substitute_clone :: CoreBndr -> CoreExpr -> Transform
-- If we see the var to find, replace it by a uniqued version of repl
substitute_clone find repl context (Var var) | find == var = do
  repl' <- Trans.lift $ CoreTools.genUniques repl
  -- NOTE(review): the replacement is reported as a change here; confirm
  -- that callers expect the changed flag to be set by this transformation.
  change repl'

-- For all other expressions, just look in subexpressions
substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
-- | Is the given thing representable at runtime, judging by its type? This
-- is 'isRepr'' lifted into the 'TransformMonad'.
isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
isRepr = Trans.lift . isRepr'
-- | Is the given thing representable at runtime, judging by its type?
-- Things for which no type can be obtained are never representable.
isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
isRepr' tything = maybe (return False) check (CoreTools.getType tything)
  where
    -- Run the representability check on the type state of the session
    check ty = MonadState.lift tsType $ VHDLTools.isReprType ty
-- | Is the given expression a reference to a local variable? Only a Var
-- expression can qualify, and only when it refers to neither a global
-- binder nor a data constructor.
is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
is_local_var (CoreSyn.Var v) = do
  globals <- getGlobalBinders
  -- A datacon id is not a global binder, but it is not a local variable
  -- either, so it has to be excluded separately.
  let isDataCon = Id.isDataConWorkId v
      isGlobal = v `elem` globals
  return $ not (isDataCon || isGlobal)
is_local_var _ = return False
-- | Is the given binder defined by the user?
isUserDefined :: CoreSyn.CoreBndr -> Bool
isUserDefined bndr
  -- System names are certain to not be user defined
  | Name.isSystemName (Id.idName bndr) = False
  -- Builtin functions are usually not user-defined either (and would
  -- break currently if they are...)
  | otherwise = Name.getOccString bndr `notElem` builtinIds
-- | Is the given binder normalizable? This means that its type signature can be
-- represented in hardware, which should (?) guarantee that it can be made
-- into hardware. This checks whether all the arguments and (optionally)
-- the return value are representable.
isNormalizeable ::
  Bool -- ^ Allow the result to be unrepresentable?
  -> CoreBndr -- ^ The binder to check
  -> TranslatorSession Bool -- ^ Is it normalizeable?
isNormalizeable result_nonrep bndr = do
  let ty = Id.idType bndr
  -- Split the type into its argument types and its result type
  let (arg_tys, res_ty) = Type.splitFunTys ty
  -- The result type is only checked when it is required to be representable
  let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys)
  andM $ mapM isRepr' check_tys