Revert "Make inlinebind work for non-recursive lets."
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module CLasH.Normalize.NormalizeTools where
6
7 -- Standard modules
8 import qualified Data.Monoid as Monoid
9 import qualified Data.Either as Either
10 import qualified Control.Monad as Monad
11 import qualified Control.Monad.Trans.Writer as Writer
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import qualified Data.Accessor.Monad.Trans.State as MonadState
14 -- import Debug.Trace
15
16 -- GHC API
17 import CoreSyn
18 import qualified Name
19 import qualified Id
20 import qualified CoreSubst
21 import qualified Type
22 -- import qualified CoreUtils
23 -- import Outputable ( showSDoc, ppr, nest )
24
25 -- Local imports
26 import CLasH.Normalize.NormalizeTypes
27 import CLasH.Translator.TranslatorTypes
28 import CLasH.VHDL.Constants (builtinIds)
29 import CLasH.Utils
30 import qualified CLasH.Utils.Core.CoreTools as CoreTools
31 import qualified CLasH.VHDL.VHDLTools as VHDLTools
32
33 -- Apply the given transformation to all expressions in the given expression,
34 -- including the expression itself.
35 everywhere :: (String, Transform) -> Transform
36 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
37
38 -- Apply the first transformation, followed by the second transformation, and
39 -- keep applying both for as long as expression still changes.
40 applyboth :: Transform -> (String, Transform) -> Transform
41 applyboth first (name, second) context expr = do
42   -- Apply the first
43   expr' <- first context expr
44   -- Apply the second
45   (expr'', changed) <- Writer.listen $ second context expr'
46   if Monoid.getAny $
47         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
48         changed 
49     then
50      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
51      --        ++ "Context: " ++ show context ++ "\n"
52      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
53       do
54         Trans.lift $ MonadState.modify tsTransformCounter (+1)
55         applyboth first (name, second) context expr'' 
56     else 
57       -- trace ("No changes") $
58       return expr''
59
60 -- Apply the given transformation to all direct subexpressions (only), not the
61 -- expression itself.
62 subeverywhere :: Transform -> Transform
63 subeverywhere trans c (App a b) = do
64   a' <- trans (AppFirst:c) a
65   b' <- trans (AppSecond:c) b
66   return $ App a' b'
67
68 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
69   bexpr' <- trans (LetBinding:c) bexpr
70   expr' <- trans (LetBody:c) expr
71   return $ Let (NonRec b bexpr') expr'
72
73 subeverywhere trans c (Let (Rec binds) expr) = do
74   expr' <- trans (LetBody:c) expr
75   binds' <- mapM transbind binds
76   return $ Let (Rec binds') expr'
77   where
78     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
79     transbind (b, e) = do
80       e' <- trans (LetBinding:c) e
81       return (b, e')
82
83 subeverywhere trans c (Lam x expr) = do
84   expr' <- trans (LambdaBody:c) expr
85   return $ Lam x expr'
86
87 subeverywhere trans c (Case scrut b t alts) = do
88   scrut' <- trans (Other:c) scrut
89   alts' <- mapM transalt alts
90   return $ Case scrut' b t alts'
91   where
92     transalt :: CoreAlt -> TransformMonad CoreAlt
93     transalt (con, binders, expr) = do
94       expr' <- trans (Other:c) expr
95       return (con, binders, expr')
96
97 subeverywhere trans c (Var x) = return $ Var x
98 subeverywhere trans c (Lit x) = return $ Lit x
99 subeverywhere trans c (Type x) = return $ Type x
100
101 subeverywhere trans c (Cast expr ty) = do
102   expr' <- trans (Other:c) expr
103   return $ Cast expr' ty
104
105 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
106
107 -- Runs each of the transforms repeatedly inside the State monad.
108 dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
109 dotransforms transs expr = do
110   (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> trans [] e) expr transs
111   if Monoid.getAny changed then dotransforms transs expr' else return expr'
112
113 -- Inline all let bindings that satisfy the given condition
114 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
115 inlinebind condition context expr@(Let (Rec binds) res) = do
116     -- Find all bindings that adhere to the condition
117     res_eithers <- mapM docond binds
118     case Either.partitionEithers res_eithers of
119       -- No replaces? No change
120       ([], _) -> return expr
121       (replace, others) -> do
122         -- Substitute the to be replaced binders with their expression
123         newexpr <- Monad.foldM (\e (bndr, repl) -> substitute_clone bndr repl context e) (Let (Rec others) res) replace
124         change newexpr
125   where 
126     docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
127     docond b = do
128       res <- condition b
129       return $ case res of True -> Left b; False -> Right b
130
131 -- Leave all other expressions unchanged
132 inlinebind _ context expr = return expr
133
134 -- Sets the changed flag in the TransformMonad, to signify that some
135 -- transform has changed the result
136 setChanged :: TransformMonad ()
137 setChanged = Writer.tell (Monoid.Any True)
138
139 -- Sets the changed flag and returns the given value.
140 change :: a -> TransformMonad a
141 change val = do
142   setChanged
143   return val
144
145 -- Returns the given value and sets the changed flag if the bool given is
146 -- True. Note that this will not unset the changed flag if the bool is False.
147 changeif :: Bool -> a -> TransformMonad a
148 changeif True val = change val
149 changeif False val = return val
150
151 -- | Creates a transformation that substitutes the given binder with the given
152 -- expression (This can be a type variable, replace by a Type expression).
153 -- Does not set the changed flag.
154 substitute :: CoreBndr -> CoreExpr -> Transform
155 -- Use CoreSubst to subst a type var in an expression
156 substitute find repl context expr = do
157   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
158   return $ CoreSubst.substExpr subst expr 
159
160 -- | Creates a transformation that substitutes the given binder with the given
161 -- expression. This does only work for value expressions! All binders in the
162 -- expression are cloned before the replacement, to guarantee uniqueness.
163 substitute_clone :: CoreBndr -> CoreExpr -> Transform
164 -- If we see the var to find, replace it by a uniqued version of repl
165 substitute_clone find repl context (Var var) | find == var = do
166   repl' <- Trans.lift $ CoreTools.genUniques repl
167   change repl'
168
169 -- For all other expressions, just look in subexpressions
170 substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
171
172 -- Is the given expression representable at runtime, based on the type?
173 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
174 isRepr tything = Trans.lift (isRepr' tything)
175
176 isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
177 isRepr' tything = case CoreTools.getType tything of
178   Nothing -> return False
179   Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty 
180
181 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
182 is_local_var (CoreSyn.Var v) = do
183   bndrs <- getGlobalBinders
184   return $ v `notElem` bndrs
185 is_local_var _ = return False
186
187 -- Is the given binder defined by the user?
188 isUserDefined :: CoreSyn.CoreBndr -> Bool
189 -- System names are certain to not be user defined
190 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
191 -- Builtin functions are usually not user-defined either (and would
192 -- break currently if they are...)
193 isUserDefined bndr = str `notElem` builtinIds
194   where
195     str = Name.getOccString bndr
196
197 -- Is the given binder normalizable? This means that its type signature can be
198 -- represented in hardware, which should (?) guarantee that it can be made
199 -- into hardware. Note that if a binder is not normalizable, it might become
200 -- so using argument propagation.
201 isNormalizeable :: CoreBndr -> TransformMonad Bool 
202 isNormalizeable bndr = Trans.lift (isNormalizeable' bndr)
203
204 isNormalizeable' :: CoreBndr -> TranslatorSession Bool 
205 isNormalizeable' bndr = do
206   let ty = Id.idType bndr
207   let (arg_tys, res_ty) = Type.splitFunTys ty
208   -- This function is normalizable if all its arguments and return value are
209   -- representable.
210   andM $ mapM isRepr' (res_ty:arg_tys)