Allow normalized functions to have a non-representable result.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module CLasH.Normalize.NormalizeTools where
6
7 -- Standard modules
8 import qualified Data.Monoid as Monoid
9 import qualified Data.Either as Either
10 import qualified Control.Monad as Monad
11 import qualified Control.Monad.Trans.Writer as Writer
12 import qualified "transformers" Control.Monad.Trans as Trans
13 import qualified Data.Accessor.Monad.Trans.State as MonadState
14 -- import Debug.Trace
15
16 -- GHC API
17 import CoreSyn
18 import qualified Name
19 import qualified Id
20 import qualified CoreSubst
21 import qualified Type
22 -- import qualified CoreUtils
23 -- import Outputable ( showSDoc, ppr, nest )
24
25 -- Local imports
26 import CLasH.Normalize.NormalizeTypes
27 import CLasH.Translator.TranslatorTypes
28 import CLasH.VHDL.Constants (builtinIds)
29 import CLasH.Utils
30 import qualified CLasH.Utils.Core.CoreTools as CoreTools
31 import qualified CLasH.VHDL.VHDLTools as VHDLTools
32
33 -- Apply the given transformation to all expressions in the given expression,
34 -- including the expression itself.
35 everywhere :: (String, Transform) -> Transform
36 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
37
38 -- Apply the first transformation, followed by the second transformation, and
39 -- keep applying both for as long as expression still changes.
40 applyboth :: Transform -> (String, Transform) -> Transform
41 applyboth first (name, second) context expr = do
42   -- Apply the first
43   expr' <- first context expr
44   -- Apply the second
45   (expr'', changed) <- Writer.listen $ second context expr'
46   if Monoid.getAny $
47         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
48         changed 
49     then
50      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
51      --        ++ "Context: " ++ show context ++ "\n"
52      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
53       do
54         Trans.lift $ MonadState.modify tsTransformCounter (+1)
55         applyboth first (name, second) context expr'' 
56     else 
57       -- trace ("No changes") $
58       return expr''
59
60 -- Apply the given transformation to all direct subexpressions (only), not the
61 -- expression itself.
62 subeverywhere :: Transform -> Transform
63 subeverywhere trans c (App a b) = do
64   a' <- trans (AppFirst:c) a
65   b' <- trans (AppSecond:c) b
66   return $ App a' b'
67
68 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
69   bexpr' <- trans (LetBinding:c) bexpr
70   expr' <- trans (LetBody:c) expr
71   return $ Let (NonRec b bexpr') expr'
72
73 subeverywhere trans c (Let (Rec binds) expr) = do
74   expr' <- trans (LetBody:c) expr
75   binds' <- mapM transbind binds
76   return $ Let (Rec binds') expr'
77   where
78     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
79     transbind (b, e) = do
80       e' <- trans (LetBinding:c) e
81       return (b, e')
82
83 subeverywhere trans c (Lam x expr) = do
84   expr' <- trans (LambdaBody:c) expr
85   return $ Lam x expr'
86
87 subeverywhere trans c (Case scrut b t alts) = do
88   scrut' <- trans (Other:c) scrut
89   alts' <- mapM transalt alts
90   return $ Case scrut' b t alts'
91   where
92     transalt :: CoreAlt -> TransformMonad CoreAlt
93     transalt (con, binders, expr) = do
94       expr' <- trans (Other:c) expr
95       return (con, binders, expr')
96
97 subeverywhere trans c (Var x) = return $ Var x
98 subeverywhere trans c (Lit x) = return $ Lit x
99 subeverywhere trans c (Type x) = return $ Type x
100
101 subeverywhere trans c (Cast expr ty) = do
102   expr' <- trans (Other:c) expr
103   return $ Cast expr' ty
104
105 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
106
107 -- Runs each of the transforms repeatedly inside the State monad.
108 dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
109 dotransforms transs expr = do
110   (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> trans [] e) expr transs
111   if Monoid.getAny changed then dotransforms transs expr' else return expr'
112
113 -- Inline all let bindings that satisfy the given condition
114 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
115 inlinebind condition context expr@(Let (Rec binds) res) = do
116     -- Find all bindings that adhere to the condition
117     res_eithers <- mapM docond binds
118     case Either.partitionEithers res_eithers of
119       -- No replaces? No change
120       ([], _) -> return expr
121       (replace, others) -> do
122         -- Substitute the to be replaced binders with their expression
123         newexpr <- do_substitute replace (Let (Rec others) res)
124         change newexpr
125   where 
126     -- Apply the condition to a let binding and return an Either
127     -- depending on whether it needs to be inlined or not.
128     docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
129     docond b = do
130       res <- condition b
131       return $ case res of True -> Left b; False -> Right b
132
133     -- Apply the given list of substitutions to the the given expression
134     do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
135     do_substitute [] expr = return expr
136     do_substitute ((bndr, val):reps) expr = do
137       -- Perform this substitution in the expression
138       expr' <- substitute_clone bndr val context expr
139       -- And in the substitution values we will be using next
140       reps' <- mapM (subs_bind bndr val) reps
141       -- And then perform the remaining substitutions
142       do_substitute reps' expr'
143    
144     -- Replace the given binder with the given expression in the
145     -- expression oft the given let binding
146     subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
147     subs_bind bndr expr (b, v) = do
148       v' <- substitute_clone  bndr expr (LetBinding:context) v
149       return (b, v')
150
151
152 -- Leave all other expressions unchanged
153 inlinebind _ context expr = return expr
154
155 -- Sets the changed flag in the TransformMonad, to signify that some
156 -- transform has changed the result
157 setChanged :: TransformMonad ()
158 setChanged = Writer.tell (Monoid.Any True)
159
160 -- Sets the changed flag and returns the given value.
161 change :: a -> TransformMonad a
162 change val = do
163   setChanged
164   return val
165
166 -- Returns the given value and sets the changed flag if the bool given is
167 -- True. Note that this will not unset the changed flag if the bool is False.
168 changeif :: Bool -> a -> TransformMonad a
169 changeif True val = change val
170 changeif False val = return val
171
172 -- | Creates a transformation that substitutes the given binder with the given
173 -- expression (This can be a type variable, replace by a Type expression).
174 -- Does not set the changed flag.
175 substitute :: CoreBndr -> CoreExpr -> Transform
176 -- Use CoreSubst to subst a type var in an expression
177 substitute find repl context expr = do
178   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
179   return $ CoreSubst.substExpr subst expr 
180
181 -- | Creates a transformation that substitutes the given binder with the given
182 -- expression. This does only work for value expressions! All binders in the
183 -- expression are cloned before the replacement, to guarantee uniqueness.
184 substitute_clone :: CoreBndr -> CoreExpr -> Transform
185 -- If we see the var to find, replace it by a uniqued version of repl
186 substitute_clone find repl context (Var var) | find == var = do
187   repl' <- Trans.lift $ CoreTools.genUniques repl
188   change repl'
189
190 -- For all other expressions, just look in subexpressions
191 substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
192
193 -- Is the given expression representable at runtime, based on the type?
194 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
195 isRepr tything = Trans.lift (isRepr' tything)
196
197 isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
198 isRepr' tything = case CoreTools.getType tything of
199   Nothing -> return False
200   Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty 
201
202 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
203 is_local_var (CoreSyn.Var v) = do
204   bndrs <- getGlobalBinders
205   return $ v `notElem` bndrs
206 is_local_var _ = return False
207
208 -- Is the given binder defined by the user?
209 isUserDefined :: CoreSyn.CoreBndr -> Bool
210 -- System names are certain to not be user defined
211 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
212 -- Builtin functions are usually not user-defined either (and would
213 -- break currently if they are...)
214 isUserDefined bndr = str `notElem` builtinIds
215   where
216     str = Name.getOccString bndr
217
218 -- | Is the given binder normalizable? This means that its type signature can be
219 -- represented in hardware, which should (?) guarantee that it can be made
220 -- into hardware. This checks whether all the arguments and (optionally)
221 -- the return value are
222 -- representable.
223 isNormalizeable :: 
224   Bool -- ^ Allow the result to be unrepresentable?
225   -> CoreBndr  -- ^ The binder to check
226   -> TranslatorSession Bool  -- ^ Is it normalizeable?
227 isNormalizeable result_nonrep bndr = do
228   let ty = Id.idType bndr
229   let (arg_tys, res_ty) = Type.splitFunTys ty
230   let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
231   andM $ mapM isRepr' check_tys