Use the list of builtin functions for inlining.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module CLasH.Normalize.NormalizeTools where
6
7 -- Standard modules
8 import qualified Data.Monoid as Monoid
9 import qualified Control.Monad as Monad
10 import qualified Control.Monad.Trans.Writer as Writer
11 import qualified "transformers" Control.Monad.Trans as Trans
12 import qualified Data.Accessor.Monad.Trans.State as MonadState
13 -- import Debug.Trace
14
15 -- GHC API
16 import CoreSyn
17 import qualified Name
18 import qualified Id
19 import qualified CoreSubst
20 import qualified Type
21 -- import qualified CoreUtils
22 -- import Outputable ( showSDoc, ppr, nest )
23
24 -- Local imports
25 import CLasH.Normalize.NormalizeTypes
26 import CLasH.Translator.TranslatorTypes
27 import CLasH.VHDL.Constants (builtinIds)
28 import CLasH.Utils
29 import qualified CLasH.Utils.Core.CoreTools as CoreTools
30 import qualified CLasH.VHDL.VHDLTools as VHDLTools
31
32 -- Apply the given transformation to all expressions in the given expression,
33 -- including the expression itself.
34 everywhere :: (String, Transform) -> Transform
35 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
36
37 -- Apply the first transformation, followed by the second transformation, and
38 -- keep applying both for as long as expression still changes.
39 applyboth :: Transform -> (String, Transform) -> Transform
40 applyboth first (name, second) expr = do
41   -- Apply the first
42   expr' <- first expr
43   -- Apply the second
44   (expr'', changed) <- Writer.listen $ second expr'
45   if Monoid.getAny $
46         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
47         changed 
48     then 
49      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
50      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
51       applyboth first (name, second)
52         expr'' 
53     else 
54       -- trace ("No changes") $
55       return expr''
56
57 -- Apply the given transformation to all direct subexpressions (only), not the
58 -- expression itself.
59 subeverywhere :: Transform -> Transform
60 subeverywhere trans (App a b) = do
61   a' <- trans a
62   b' <- trans b
63   return $ App a' b'
64
65 subeverywhere trans (Let (NonRec b bexpr) expr) = do
66   bexpr' <- trans bexpr
67   expr' <- trans expr
68   return $ Let (NonRec b bexpr') expr'
69
70 subeverywhere trans (Let (Rec binds) expr) = do
71   expr' <- trans expr
72   binds' <- mapM transbind binds
73   return $ Let (Rec binds') expr'
74   where
75     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
76     transbind (b, e) = do
77       e' <- trans e
78       return (b, e')
79
80 subeverywhere trans (Lam x expr) = do
81   expr' <- trans expr
82   return $ Lam x expr'
83
84 subeverywhere trans (Case scrut b t alts) = do
85   scrut' <- trans scrut
86   alts' <- mapM transalt alts
87   return $ Case scrut' b t alts'
88   where
89     transalt :: CoreAlt -> TransformMonad CoreAlt
90     transalt (con, binders, expr) = do
91       expr' <- trans expr
92       return (con, binders, expr')
93
94 subeverywhere trans (Var x) = return $ Var x
95 subeverywhere trans (Lit x) = return $ Lit x
96 subeverywhere trans (Type x) = return $ Type x
97
98 subeverywhere trans (Cast expr ty) = do
99   expr' <- trans expr
100   return $ Cast expr' ty
101
102 subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
103
104 -- Apply the given transformation to all expressions, except for direct
105 -- arguments of an application
106 notappargs :: (String, Transform) -> Transform
107 notappargs trans = applyboth (subnotappargs trans) trans
108
109 -- Apply the given transformation to all (direct and indirect) subexpressions
110 -- (but not the expression itself), except for direct arguments of an
111 -- application
112 subnotappargs :: (String, Transform) -> Transform
113 subnotappargs trans (App a b) = do
114   a' <- subnotappargs trans a
115   b' <- subnotappargs trans b
116   return $ App a' b'
117
118 -- Let subeverywhere handle all other expressions
119 subnotappargs trans expr = subeverywhere (notappargs trans) expr
120
121 -- Runs each of the transforms repeatedly inside the State monad.
122 dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
123 dotransforms transs expr = do
124   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
125   if Monoid.getAny changed then dotransforms transs expr' else return expr'
126
127 -- Inline all let bindings that satisfy the given condition
128 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
129 inlinebind condition expr@(Let (NonRec bndr expr') res) = do
130     applies <- condition (bndr, expr')
131     if applies
132       then do
133         -- Substitute the binding in res and return that
134         res' <- substitute_clone bndr expr' res
135         change res'
136       else
137         -- Don't change this let
138         return expr
139 -- Leave all other expressions unchanged
140 inlinebind _ expr = return expr
141
142 -- Sets the changed flag in the TransformMonad, to signify that some
143 -- transform has changed the result
144 setChanged :: TransformMonad ()
145 setChanged = Writer.tell (Monoid.Any True)
146
147 -- Sets the changed flag and returns the given value.
148 change :: a -> TransformMonad a
149 change val = do
150   setChanged
151   return val
152
153 -- Returns the given value and sets the changed flag if the bool given is
154 -- True. Note that this will not unset the changed flag if the bool is False.
155 changeif :: Bool -> a -> TransformMonad a
156 changeif True val = change val
157 changeif False val = return val
158
159 -- | Creates a transformation that substitutes the given binder with the given
160 -- expression (This can be a type variable, replace by a Type expression).
161 -- Does not set the changed flag.
162 substitute :: CoreBndr -> CoreExpr -> Transform
163 -- Use CoreSubst to subst a type var in an expression
164 substitute find repl expr = do
165   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
166   return $ CoreSubst.substExpr subst expr 
167
168 -- | Creates a transformation that substitutes the given binder with the given
169 -- expression. This does only work for value expressions! All binders in the
170 -- expression are cloned before the replacement, to guarantee uniqueness.
171 substitute_clone :: CoreBndr -> CoreExpr -> Transform
172 -- If we see the var to find, replace it by a uniqued version of repl
173 substitute_clone find repl (Var var) | find == var = do
174   repl' <- Trans.lift $ CoreTools.genUniques repl
175   change repl'
176
177 -- For all other expressions, just look in subexpressions
178 substitute_clone find repl expr = subeverywhere (substitute_clone find repl) expr
179
180 -- Is the given expression representable at runtime, based on the type?
181 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
182 isRepr tything = Trans.lift (isRepr' tything)
183
184 isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
185 isRepr' tything = case CoreTools.getType tything of
186   Nothing -> return False
187   Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty 
188
189 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
190 is_local_var (CoreSyn.Var v) = do
191   bndrs <- getGlobalBinders
192   return $ v `notElem` bndrs
193 is_local_var _ = return False
194
195 -- Is the given binder defined by the user?
196 isUserDefined :: CoreSyn.CoreBndr -> Bool
197 -- System names are certain to not be user defined
198 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
199 -- Builtin functions are usually not user-defined either (and would
200 -- break currently if they are...)
201 isUserDefined bndr = str `notElem` builtinIds
202   where
203     str = Name.getOccString bndr
204
205 -- Is the given binder normalizable? This means that its type signature can be
206 -- represented in hardware, which should (?) guarantee that it can be made
207 -- into hardware. Note that if a binder is not normalizable, it might become
208 -- so using argument propagation.
209 isNormalizeable :: CoreBndr -> TransformMonad Bool 
210 isNormalizeable bndr = Trans.lift (isNormalizeable' bndr)
211
212 isNormalizeable' :: CoreBndr -> TranslatorSession Bool 
213 isNormalizeable' bndr = do
214   let ty = Id.idType bndr
215   let (arg_tys, res_ty) = Type.splitFunTys ty
216   -- This function is normalizable if all its arguments and return value are
217   -- representable.
218   andM $ mapM isRepr' (res_ty:arg_tys)