Add changeif normalization helper function.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module CLasH.Normalize.NormalizeTools where
6
7 -- Standard modules
8 import Debug.Trace
9 import qualified List
10 import qualified Data.Monoid as Monoid
11 import qualified Data.Either as Either
12 import qualified Control.Arrow as Arrow
13 import qualified Control.Monad as Monad
14 import qualified Control.Monad.Trans.State as State
15 import qualified Control.Monad.Trans.Writer as Writer
16 import qualified "transformers" Control.Monad.Trans as Trans
17 import qualified Data.Map as Map
18 import Data.Accessor
19 import Data.Accessor.MonadState as MonadState
20
21 -- GHC API
22 import CoreSyn
23 import qualified CoreSubst
24 import qualified CoreUtils
25 import Outputable ( showSDoc, ppr, nest )
26
27 -- Local imports
28 import CLasH.Normalize.NormalizeTypes
29 import CLasH.Translator.TranslatorTypes
30 import CLasH.Utils.Pretty
31 import CLasH.VHDL.VHDLTypes
32 import qualified CLasH.VHDL.VHDLTools as VHDLTools
33
34 -- Apply the given transformation to all expressions in the given expression,
35 -- including the expression itself.
36 everywhere :: (String, Transform) -> Transform
37 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
38
39 -- Apply the first transformation, followed by the second transformation, and
40 -- keep applying both for as long as expression still changes.
41 applyboth :: Transform -> (String, Transform) -> Transform
42 applyboth first (name, second) expr  = do
43   -- Apply the first
44   expr' <- first expr
45   -- Apply the second
46   (expr'', changed) <- Writer.listen $ second expr'
47   if Monoid.getAny $
48 --        trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
49         changed 
50     then 
51 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
52 --      trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
53       applyboth first (name, second) $
54         expr'' 
55     else 
56 --      trace ("No changes") $
57       return expr''
58
59 -- Apply the given transformation to all direct subexpressions (only), not the
60 -- expression itself.
61 subeverywhere :: Transform -> Transform
62 subeverywhere trans (App a b) = do
63   a' <- trans a
64   b' <- trans b
65   return $ App a' b'
66
67 subeverywhere trans (Let (NonRec b bexpr) expr) = do
68   bexpr' <- trans bexpr
69   expr' <- trans expr
70   return $ Let (NonRec b bexpr') expr'
71
72 subeverywhere trans (Let (Rec binds) expr) = do
73   expr' <- trans expr
74   binds' <- mapM transbind binds
75   return $ Let (Rec binds') expr'
76   where
77     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
78     transbind (b, e) = do
79       e' <- trans e
80       return (b, e')
81
82 subeverywhere trans (Lam x expr) = do
83   expr' <- trans expr
84   return $ Lam x expr'
85
86 subeverywhere trans (Case scrut b t alts) = do
87   scrut' <- trans scrut
88   alts' <- mapM transalt alts
89   return $ Case scrut' b t alts'
90   where
91     transalt :: CoreAlt -> TransformMonad CoreAlt
92     transalt (con, binders, expr) = do
93       expr' <- trans expr
94       return (con, binders, expr')
95
96 subeverywhere trans (Var x) = return $ Var x
97 subeverywhere trans (Lit x) = return $ Lit x
98 subeverywhere trans (Type x) = return $ Type x
99
100 subeverywhere trans (Cast expr ty) = do
101   expr' <- trans expr
102   return $ Cast expr' ty
103
104 subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
105
106 -- Apply the given transformation to all expressions, except for direct
107 -- arguments of an application
108 notappargs :: (String, Transform) -> Transform
109 notappargs trans = applyboth (subnotappargs trans) trans
110
111 -- Apply the given transformation to all (direct and indirect) subexpressions
112 -- (but not the expression itself), except for direct arguments of an
113 -- application
114 subnotappargs :: (String, Transform) -> Transform
115 subnotappargs trans (App a b) = do
116   a' <- subnotappargs trans a
117   b' <- subnotappargs trans b
118   return $ App a' b'
119
120 -- Let subeverywhere handle all other expressions
121 subnotappargs trans expr = subeverywhere (notappargs trans) expr
122
123 -- Runs each of the transforms repeatedly inside the State monad.
124 dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
125 dotransforms transs expr = do
126   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
127   if Monoid.getAny changed then dotransforms transs expr' else return expr'
128
129 -- Inline all let bindings that satisfy the given condition
130 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
131 inlinebind condition expr@(Let (Rec binds) res) = do
132     -- Find all bindings that adhere to the condition
133     res_eithers <- mapM docond binds
134     case Either.partitionEithers res_eithers of
135       -- No replaces? No change
136       ([], _) -> return expr
137       (replace, others) -> do
138         -- Substitute the to be replaced binders with their expression
139         let newexpr = substitute replace (Let (Rec others) res)
140         change newexpr
141   where 
142     docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
143     docond b = do
144       res <- condition b
145       return $ case res of True -> Left b; False -> Right b
146
147 -- Leave all other expressions unchanged
148 inlinebind _ expr = return expr
149
150 -- Sets the changed flag in the TransformMonad, to signify that some
151 -- transform has changed the result
152 setChanged :: TransformMonad ()
153 setChanged = Writer.tell (Monoid.Any True)
154
155 -- Sets the changed flag and returns the given value.
156 change :: a -> TransformMonad a
157 change val = do
158   setChanged
159   return val
160
161 -- Returns the given value and sets the changed flag if the bool given is
162 -- True. Note that this will not unset the changed flag if the bool is False.
163 changeif :: Bool -> a -> TransformMonad a
164 changeif True val = change val
165 changeif False val = return val
166
167 -- Replace each of the binders given with the coresponding expressions in the
168 -- given expression.
169 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
170 substitute [] expr = expr
171 -- Apply one substitution on the expression, but also on any remaining
172 -- substitutions. This seems to be the only way to handle substitutions like
173 -- [(b, c), (a, b)]. This means we reuse a substitution, which is not allowed
174 -- according to CoreSubst documentation (but it doesn't seem to be a problem).
175 -- TODO: Find out how this works, exactly.
176 substitute ((b, e):subss) expr = substitute subss' expr'
177   where 
178     -- Create the Subst
179     subs = (CoreSubst.extendSubst CoreSubst.emptySubst b e)
180     -- Apply this substitution to the main expression
181     expr' = CoreSubst.substExpr subs expr
182     -- Apply this substitution on all the expressions in the remaining
183     -- substitutions
184     subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
185
186 -- Is the given expression representable at runtime, based on the type?
187 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
188 isRepr (Type ty) = return False
189 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
190
191 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
192 is_local_var (CoreSyn.Var v) = do
193   bndrs <- getGlobalBinders
194   return $ not $ v `elem` bndrs
195 is_local_var _ = return False