1 {-# LANGUAGE PackageImports #-}
3 -- This module provides functions for program transformations.
5 module CLasH.Normalize.NormalizeTools where
8 import qualified Data.Monoid as Monoid
9 import qualified Control.Monad as Monad
10 import qualified Control.Monad.Trans.Writer as Writer
11 import qualified "transformers" Control.Monad.Trans as Trans
12 import qualified Data.Accessor.Monad.Trans.State as MonadState
19 import qualified CoreSubst
21 -- import qualified CoreUtils
22 -- import Outputable ( showSDoc, ppr, nest )
25 import CLasH.Normalize.NormalizeTypes
26 import CLasH.Translator.TranslatorTypes
27 import CLasH.VHDL.Constants (builtinIds)
29 import qualified CLasH.Utils.Core.CoreTools as CoreTools
30 import qualified CLasH.VHDL.VHDLTools as VHDLTools
-- | Run a named transformation over an expression and, recursively, over all
-- of its subexpressions. The subexpressions are handled first (through
-- 'subeverywhere'), then the transformation itself is applied; 'applyboth'
-- keeps re-applying both steps for as long as the expression still changes.
everywhere :: (String, Transform) -> Transform
everywhere trans = go
  where
    -- 'go' is exactly 'everywhere trans', shared instead of rebuilt
    go = applyboth (subeverywhere go) trans
-- | Apply the first transformation, followed by the second (named)
-- transformation, and keep applying both for as long as the second one
-- reports a change.
--
-- Only the changes made by the second transformation are inspected here
-- (through 'Writer.listen'). 'Writer.listen' still passes them on to the
-- enclosing 'TransformMonad', so callers observe them as well. The name is
-- only referenced by the commented-out tracing below.
applyboth :: Transform -> (String, Transform) -> Transform
applyboth first (name, second) context expr = do
  -- Apply the first transformation
  expr' <- first context expr
  -- Apply the second transformation, capturing whether it changed anything
  (expr'', changed) <- Writer.listen $ second context expr'
  -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
  if Monoid.getAny changed
    then
      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
      -- Something changed: try both transformations again on the result
      applyboth first (name, second) context expr''
    else
      -- trace ("No changes") $
      -- Fixpoint reached for this expression
      return expr''
-- | Apply the given transformation to all direct subexpressions (only), not
-- to the expression itself.
--
-- Each recursive call gets the context extended with the position of the
-- subexpression: 'AppFirst' / 'AppSecond' for the function and argument of an
-- application, 'Other' everywhere else. Variables, literals and types have no
-- subexpressions and are returned unchanged. Any other expression form is
-- rejected with an error.
subeverywhere :: Transform -> Transform
subeverywhere trans c (App a b) = do
  a' <- trans (AppFirst:c) a
  b' <- trans (AppSecond:c) b
  return $ App a' b'

subeverywhere trans c (Let (NonRec b bexpr) expr) = do
  bexpr' <- trans (Other:c) bexpr
  expr' <- trans (Other:c) expr
  return $ Let (NonRec b bexpr') expr'

subeverywhere trans c (Let (Rec binds) expr) = do
  expr' <- trans (Other:c) expr
  binds' <- mapM transbind binds
  return $ Let (Rec binds') expr'
  where
    -- Transform a bound expression, keeping its binder
    transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
    transbind (b, e) = do
      e' <- trans (Other:c) e
      return (b, e')

subeverywhere trans c (Lam x expr) = do
  expr' <- trans (Other:c) expr
  return $ Lam x expr'

subeverywhere trans c (Case scrut b t alts) = do
  scrut' <- trans (Other:c) scrut
  alts' <- mapM transalt alts
  return $ Case scrut' b t alts'
  where
    -- Transform the right-hand side of an alternative, keeping its pattern
    transalt :: CoreAlt -> TransformMonad CoreAlt
    transalt (con, binders, rhs) = do
      rhs' <- trans (Other:c) rhs
      return (con, binders, rhs')

subeverywhere _ _ (Var x) = return $ Var x
subeverywhere _ _ (Lit x) = return $ Lit x
subeverywhere _ _ (Type x) = return $ Type x

subeverywhere trans c (Cast expr ty) = do
  expr' <- trans (Other:c) expr
  return $ Cast expr' ty

subeverywhere _ _ expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
-- | Apply the given named transformation to an expression and to its
-- subexpressions, except that it is never applied directly to the function
-- or argument part of an application ('subnotappargs' only descends into
-- those). As with 'everywhere', 'applyboth' repeats the process while the
-- expression keeps changing.
notappargs :: (String, Transform) -> Transform
notappargs trans context expr = applyboth (subnotappargs trans) trans context expr
-- | Apply the given named transformation to all (direct and indirect)
-- subexpressions, but not to the expression itself, and not to the direct
-- operands of an application.
--
-- For an 'App' node, both the function and the argument are only descended
-- into, never transformed themselves. Every other expression form is handed
-- to 'subeverywhere', which applies 'notappargs' to its direct children.
subnotappargs :: (String, Transform) -> Transform
subnotappargs trans c (App a b) = do
  a' <- subnotappargs trans (Other:c) a
  b' <- subnotappargs trans (Other:c) b
  return $ App a' b'

-- Let subeverywhere handle all other expressions
subnotappargs trans c expr = subeverywhere (notappargs trans) c expr
-- | Normalize an expression with the given transforms.
--
-- One pass runs the transforms one after another, each starting from an
-- empty context. The whole pass is repeated for as long as any transform in
-- it reports a change. The change tracking ('Writer.runWriterT') is unwrapped
-- here, so the result lives in the 'TranslatorSession'.
dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
dotransforms transs = fixpoint
  where
    fixpoint e = do
      (e', changed) <- Writer.runWriterT (runPass e)
      if Monoid.getAny changed
        then fixpoint e'
        else return e'
    -- A single pass: thread the expression through every transform in order
    runPass e = Monad.foldM (\acc t -> t [] acc) e transs
-- | Inline all non-recursive let bindings that satisfy the given condition.
--
-- When the condition holds for a binding, the binder is substituted in the
-- let body with a cloned copy of its bound expression (see
-- 'substitute_clone'). The let itself is dropped and the changed flag is set.
-- All other expressions, including recursive lets, are left unchanged.
inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
inlinebind condition context expr@(Let (NonRec bndr expr') res) = do
  applies <- condition (bndr, expr')
  if applies
    then do
      -- Substitute the binding in res and return that. Dropping the let is a
      -- change even when bndr does not occur in res, so flag it explicitly.
      res' <- substitute_clone bndr expr' context res
      change res'
    else
      -- Don't change this let
      return expr
-- Leave all other expressions unchanged
inlinebind _ _ expr = return expr
-- | Record in the 'TransformMonad' that some transformation has modified the
-- expression. The flag is an 'Monoid.Any', so once set it stays set for the
-- rest of the writer run.
setChanged :: TransformMonad ()
setChanged = Writer.tell changedMark
  where
    changedMark = Monoid.Any True
-- | Set the changed flag (via 'setChanged') and return the given value
-- unmodified.
change :: a -> TransformMonad a
change val = do
  setChanged
  return val
-- | Return the given value, additionally setting the changed flag when the
-- condition is True. A False condition leaves the flag alone; it never
-- clears a flag that was already set.
changeif :: Bool -> a -> TransformMonad a
changeif flag val
  | flag = change val
  | otherwise = return val
-- | Creates a transformation that substitutes the given binder with the given
-- expression, using GHC's 'CoreSubst'. The binder may also be a type
-- variable; in that case pass a 'Type' expression as the replacement. The
-- context argument is ignored and the changed flag is not set.
--
-- NOTE(review): the substitution starts from 'CoreSubst.emptySubst', i.e.
-- with an empty in-scope set; confirm this cannot lead to variable capture.
substitute :: CoreBndr -> CoreExpr -> Transform
substitute find repl _ expr = return (CoreSubst.substExpr subst expr)
  where
    subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
-- | Creates a transformation that substitutes the given binder with the given
-- expression. This only works for value expressions!
--
-- Every occurrence receives its own copy of the replacement. All binders in
-- that copy are cloned first (via 'CoreTools.genUniques') to guarantee
-- uniqueness. Each replacement sets the changed flag.
substitute_clone :: CoreBndr -> CoreExpr -> Transform
-- If we see the var to find, replace it by a uniqued version of repl
substitute_clone find repl _ (Var var) | find == var = do
  repl' <- Trans.lift $ CoreTools.genUniques repl
  change repl'
-- For all other expressions, just look in subexpressions
substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
-- | Is the given thing representable at runtime, judging by its type?
-- This lifts 'isRepr'' from the 'TranslatorSession' into the
-- 'TransformMonad'.
isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
isRepr = Trans.lift . isRepr'
-- | Session-level worker for 'isRepr'.
--
-- A thing without a known type is never representable. Otherwise the answer
-- comes from 'VHDLTools.isReprType', run through the 'tsType' accessor of the
-- session state.
isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
isRepr' tything =
  maybe (return False) checkType (CoreTools.getType tything)
  where
    checkType ty = MonadState.lift tsType $ VHDLTools.isReprType ty
-- | Is the expression a variable that is not among the global binders of the
-- translator session? Expressions that are not variables are never local
-- variables.
is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
is_local_var expr = case expr of
  CoreSyn.Var v -> Monad.liftM (notElem v) getGlobalBinders
  _ -> return False
-- | Is the given binder defined by the user?
--
-- Binders with a system-generated name are not. Neither are binders whose
-- occurrence name is listed in 'builtinIds'.
isUserDefined :: CoreSyn.CoreBndr -> Bool
isUserDefined bndr
  -- System names are certain to not be user defined
  | Name.isSystemName (Id.idName bndr) = False
  -- Builtin functions are usually not user-defined either (and would
  -- break currently if they are...)
  | otherwise = str `notElem` builtinIds
  where
    str = Name.getOccString bndr
-- | Is the given binder normalizable? That is: can its type signature (every
-- argument type and the result type) be represented in hardware? This should
-- (?) guarantee that the binder can be made into hardware. A binder that is
-- not normalizable may still become so through argument propagation.
-- This lifts 'isNormalizeable'' into the 'TransformMonad'.
isNormalizeable :: CoreBndr -> TransformMonad Bool
isNormalizeable = Trans.lift . isNormalizeable'
-- | Session-level worker for 'isNormalizeable'. It splits the binder's type
-- into argument types and a result type. The binder is normalizable only if
-- all of those types, the result type included, are representable (see
-- 'isRepr'').
isNormalizeable' :: CoreBndr -> TranslatorSession Bool
isNormalizeable' bndr = andM $ mapM isRepr' (res_ty : arg_tys)
  where
    (arg_tys, res_ty) = Type.splitFunTys (Id.idType bndr)