Make inlinebind work for non-recursive lets.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
1 {-# LANGUAGE PackageImports #-}
2 -- 
3 -- This module provides functions for program transformations.
4 --
5 module CLasH.Normalize.NormalizeTools where
6
7 -- Standard modules
8 import Debug.Trace
9 import qualified List
10 import qualified Data.Monoid as Monoid
11 import qualified Data.Either as Either
12 import qualified Control.Arrow as Arrow
13 import qualified Control.Monad as Monad
14 import qualified Control.Monad.Trans.State as State
15 import qualified Control.Monad.Trans.Writer as Writer
16 import qualified "transformers" Control.Monad.Trans as Trans
17 import qualified Data.Map as Map
18 import Data.Accessor
19 import Data.Accessor.MonadState as MonadState
20
21 -- GHC API
22 import CoreSyn
23 import qualified CoreSubst
24 import qualified CoreUtils
25 import Outputable ( showSDoc, ppr, nest )
26
27 -- Local imports
28 import CLasH.Normalize.NormalizeTypes
29 import CLasH.Translator.TranslatorTypes
30 import CLasH.Utils.Pretty
31 import CLasH.VHDL.VHDLTypes
32 import qualified CLasH.VHDL.VHDLTools as VHDLTools
33
34 -- Apply the given transformation to all expressions in the given expression,
35 -- including the expression itself.
36 everywhere :: (String, Transform) -> Transform
37 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
38
39 -- Apply the first transformation, followed by the second transformation, and
40 -- keep applying both for as long as expression still changes.
41 applyboth :: Transform -> (String, Transform) -> Transform
42 applyboth first (name, second) expr  = do
43   -- Apply the first
44   expr' <- first expr
45   -- Apply the second
46   (expr'', changed) <- Writer.listen $ second expr'
47   if Monoid.getAny $
48 --        trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
49         changed 
50     then 
51 --      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
52 --      trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
53       applyboth first (name, second) $
54         expr'' 
55     else 
56 --      trace ("No changes") $
57       return expr''
58
59 -- Apply the given transformation to all direct subexpressions (only), not the
60 -- expression itself.
61 subeverywhere :: Transform -> Transform
62 subeverywhere trans (App a b) = do
63   a' <- trans a
64   b' <- trans b
65   return $ App a' b'
66
67 subeverywhere trans (Let (NonRec b bexpr) expr) = do
68   bexpr' <- trans bexpr
69   expr' <- trans expr
70   return $ Let (NonRec b bexpr') expr'
71
72 subeverywhere trans (Let (Rec binds) expr) = do
73   expr' <- trans expr
74   binds' <- mapM transbind binds
75   return $ Let (Rec binds') expr'
76   where
77     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
78     transbind (b, e) = do
79       e' <- trans e
80       return (b, e')
81
82 subeverywhere trans (Lam x expr) = do
83   expr' <- trans expr
84   return $ Lam x expr'
85
86 subeverywhere trans (Case scrut b t alts) = do
87   scrut' <- trans scrut
88   alts' <- mapM transalt alts
89   return $ Case scrut' b t alts'
90   where
91     transalt :: CoreAlt -> TransformMonad CoreAlt
92     transalt (con, binders, expr) = do
93       expr' <- trans expr
94       return (con, binders, expr')
95
96 subeverywhere trans (Var x) = return $ Var x
97 subeverywhere trans (Lit x) = return $ Lit x
98 subeverywhere trans (Type x) = return $ Type x
99
100 subeverywhere trans (Cast expr ty) = do
101   expr' <- trans expr
102   return $ Cast expr' ty
103
104 subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
105
106 -- Apply the given transformation to all expressions, except for direct
107 -- arguments of an application
108 notappargs :: (String, Transform) -> Transform
109 notappargs trans = applyboth (subnotappargs trans) trans
110
111 -- Apply the given transformation to all (direct and indirect) subexpressions
112 -- (but not the expression itself), except for direct arguments of an
113 -- application
114 subnotappargs :: (String, Transform) -> Transform
115 subnotappargs trans (App a b) = do
116   a' <- subnotappargs trans a
117   b' <- subnotappargs trans b
118   return $ App a' b'
119
120 -- Let subeverywhere handle all other expressions
121 subnotappargs trans expr = subeverywhere (notappargs trans) expr
122
123 -- Runs each of the transforms repeatedly inside the State monad.
124 dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
125 dotransforms transs expr = do
126   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
127   if Monoid.getAny changed then dotransforms transs expr' else return expr'
128
129 -- Inline all let bindings that satisfy the given condition
130 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
131 inlinebind condition expr@(Let (NonRec bndr expr') res) = do
132     applies <- condition (bndr, expr')
133     if applies
134       then
135         -- Substitute the binding in res and return that
136         change $ substitute [(bndr, expr')] res
137       else
138         -- Don't change this let
139         return expr
140 -- Leave all other expressions unchanged
141 inlinebind _ expr = return expr
142
143 -- Sets the changed flag in the TransformMonad, to signify that some
144 -- transform has changed the result
145 setChanged :: TransformMonad ()
146 setChanged = Writer.tell (Monoid.Any True)
147
148 -- Sets the changed flag and returns the given value.
149 change :: a -> TransformMonad a
150 change val = do
151   setChanged
152   return val
153
154 -- Returns the given value and sets the changed flag if the bool given is
155 -- True. Note that this will not unset the changed flag if the bool is False.
156 changeif :: Bool -> a -> TransformMonad a
157 changeif True val = change val
158 changeif False val = return val
159
160 -- Replace each of the binders given with the coresponding expressions in the
161 -- given expression.
162 substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
163 substitute [] expr = expr
164 -- Apply one substitution on the expression, but also on any remaining
165 -- substitutions. This seems to be the only way to handle substitutions like
166 -- [(b, c), (a, b)]. This means we reuse a substitution, which is not allowed
167 -- according to CoreSubst documentation (but it doesn't seem to be a problem).
168 -- TODO: Find out how this works, exactly.
169 substitute ((b, e):subss) expr = substitute subss' expr'
170   where 
171     -- Create the Subst
172     subs = (CoreSubst.extendSubst CoreSubst.emptySubst b e)
173     -- Apply this substitution to the main expression
174     expr' = CoreSubst.substExpr subs expr
175     -- Apply this substitution on all the expressions in the remaining
176     -- substitutions
177     subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
178
179 -- Is the given expression representable at runtime, based on the type?
180 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
181 isRepr (Type ty) = return False
182 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
183
184 is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
185 is_local_var (CoreSyn.Var v) = do
186   bndrs <- getGlobalBinders
187   return $ not $ v `elem` bndrs
188 is_local_var _ = return False