Learn flattenExpr to flatten normal applications.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import Control.Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified TyCon
8 import qualified Maybe
9 import Data.Traversable
10 import qualified DataCon
11 import qualified CoreUtils
12 import Control.Applicative
13 import Outputable ( showSDoc, ppr )
14 import qualified Data.Foldable as Foldable
15 import qualified Control.Monad.State as State
16
17 -- | A datatype that maps each of the single values in a haskell structure to
18 -- a mapto. The map has the same structure as the haskell type mapped, ie
19 -- nested tuples etc.
20 data HsValueMap mapto =
21   Tuple [HsValueMap mapto]
22   | Single mapto
23   deriving (Show, Eq)
24
25 instance Functor HsValueMap where
26   fmap f (Single s) = Single (f s)
27   fmap f (Tuple maps) = Tuple (map (fmap f) maps)
28
29 instance Foldable.Foldable HsValueMap where
30   foldMap f (Single s) = f s
31   -- The first foldMap folds a list of HsValueMaps, the second foldMap folds
32   -- each of the HsValueMaps in that list
33   foldMap f (Tuple maps) = Foldable.foldMap (Foldable.foldMap f) maps
34
35 instance Traversable HsValueMap where
36   traverse f (Single s) = Single <$> f s
37   traverse f (Tuple maps) = Tuple <$> (traverse (traverse f) maps)
38
39 data PassState s x = PassState (s -> (s, x))
40
41 instance Functor (PassState s) where
42   fmap f (PassState a) = PassState (\s -> let (s', a') = a s in (s', f a'))
43
44 instance Applicative (PassState s) where
45   pure x = PassState (\s -> (s, x))
46   PassState f <*> PassState x = PassState (\s -> let (s', f') = f s; (s'', x') = x s' in (s'', f' x'))
47
48 -- | Creates a HsValueMap with the same structure as the given type, using the
49 --   given function for mapping the single types.
50 mkHsValueMap ::
51   Type.Type                         -- ^ The type to map to a HsValueMap
52   -> HsValueMap Type.Type           -- ^ The resulting map and state
53
54 mkHsValueMap ty =
55   case Type.splitTyConApp_maybe ty of
56     Just (tycon, args) ->
57       if (TyCon.isTupleTyCon tycon) 
58         then
59           Tuple (map mkHsValueMap args)
60         else
61           Single ty
62     Nothing -> Single ty
63
64 -- Extract the arguments from a data constructor application (that is, the
65 -- normal args, leaving out the type args).
66 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
67 dataConAppArgs dc args =
68     drop tycount args
69   where
70     tycount = length $ DataCon.dataConAllTyVars dc
71
72
73
74 data FlatFunction = FlatFunction {
75   args   :: [SignalDefMap],
76   res    :: SignalUseMap,
77   --sigs   :: [SignalDef],
78   apps   :: [FApp],
79   conds  :: [CondDef]
80 } deriving (Show, Eq)
81     
82 type SignalUseMap = HsValueMap SignalUse
83 type SignalDefMap = HsValueMap SignalDef
84
85 useMapToDefMap :: SignalUseMap -> SignalDefMap
86 useMapToDefMap = fmap (\(SignalUse u) -> SignalDef u)
87
88 defMapToUseMap :: SignalDefMap -> SignalUseMap
89 defMapToUseMap = fmap (\(SignalDef u) -> SignalUse u)
90
91
92 type SignalId = Int
93 data SignalUse = SignalUse {
94   sigUseId :: SignalId
95 } deriving (Show, Eq)
96
97 data SignalDef = SignalDef {
98   sigDefId :: SignalId
99 } deriving (Show, Eq)
100
101 data FApp = FApp {
102   appFunc :: HsFunction,
103   appArgs :: [SignalUseMap],
104   appRes  :: SignalDefMap
105 } deriving (Show, Eq)
106
107 data CondDef = CondDef {
108   cond    :: SignalUse,
109   high    :: SignalUse,
110   low     :: SignalUse,
111   condRes :: SignalDef
112 } deriving (Show, Eq)
113
114 -- | How is a given (single) value in a function's type (ie, argument or
115 -- return value) used?
116 data HsValueUse = 
117   Port           -- ^ Use it as a port (input or output)
118   | State Int    -- ^ Use it as state (input or output). The int is used to
119                  --   match input state to output state.
120   | HighOrder {  -- ^ Use it as a high order function input
121     hoName :: String,  -- ^ Which function is passed in?
122     hoArgs :: [HsUseMap]   -- ^ Which arguments are already applied? This
123                          -- ^ map should only contain Port and other
124                          --   HighOrder values. 
125   }
126   deriving (Show, Eq)
127
128 type HsUseMap = HsValueMap HsValueUse
129
130 -- | Builds a HsUseMap with the same structure has the given HsValueMap in
131 --   which all the Single elements are marked as State, with increasing state
132 --   numbers.
133 useAsState :: HsValueMap a -> HsUseMap
134 useAsState map =
135   map'
136   where
137     -- Traverse the existing map, resulting in a function that maps an initial
138     -- state number to the final state number and the new map
139     PassState f = traverse asState map
140     -- Run this function to get the new map
141     (_, map')   = f 0
142     -- This function maps each element to a State with a unique number, by
143     -- incrementing the state count.
144     asState x   = PassState (\s -> (s+1, State s))
145
146 -- | Builds a HsUseMap with the same structure has the given HsValueMap in
147 --   which all the Single elements are marked as Port.
148 useAsPort :: HsValueMap a -> HsUseMap
149 useAsPort map = fmap (\x -> Port) map
150
151 data HsFunction = HsFunction {
152   hsFuncName :: String,
153   hsFuncArgs :: [HsUseMap],
154   hsFuncRes  :: HsUseMap
155 } deriving (Show, Eq)
156
157 type BindMap = [(
158   CoreBndr,            -- ^ The bind name
159   Either               -- ^ The bind value which is either
160     SignalUseMap       -- ^ a signal
161     (
162       HsValueUse,      -- ^ or a HighOrder function
163       [SignalUse]      -- ^ With these signals already applied to it
164     )
165   )]
166
167 type FlattenState = State.State ([FApp], [CondDef], SignalId)
168
169 -- | Add an application to the current FlattenState
170 addApp :: FApp -> FlattenState ()
171 addApp a = do
172   (apps, conds, n) <- State.get
173   State.put (a:apps, conds, n)
174
175 -- | Add a conditional definition to the current FlattenState
176 addCondDef :: CondDef -> FlattenState ()
177 addCondDef c = do
178   (apps, conds, n) <- State.get
179   State.put (apps, c:conds, n)
180
181 -- | Generates a new signal id, which is unique within the current flattening.
182 genSignalId :: FlattenState SignalId 
183 genSignalId = do
184   (apps, conds, n) <- State.get
185   State.put (apps, conds, n+1)
186   return n
187
188 genSignalUses ::
189   Type.Type
190   -> FlattenState SignalUseMap
191
192 genSignalUses ty = do
193   typeMapToUseMap tymap
194   where
195     -- First generate a map with the right structure containing the types
196     tymap = mkHsValueMap ty
197
198 typeMapToUseMap ::
199   HsValueMap Type.Type
200   -> FlattenState SignalUseMap
201
202 typeMapToUseMap (Single ty) = do
203   id <- genSignalId
204   return $ Single (SignalUse id)
205
206 typeMapToUseMap (Tuple tymaps) = do
207   usemaps <- State.mapM typeMapToUseMap tymaps
208   return $ Tuple usemaps
209
210 -- | Flatten a haskell function
211 flattenFunction ::
212   HsFunction                      -- ^ The function to flatten
213   -> CoreBind                     -- ^ The function value
214   -> FlatFunction                 -- ^ The resulting flat function
215
216 flattenFunction _ (Rec _) = error "Recursive binders not supported"
217 flattenFunction hsfunc bind@(NonRec var expr) =
218   FlatFunction args res apps conds
219   where
220     init_state        = ([], [], 0)
221     (fres, end_state) = State.runState (flattenExpr [] expr) init_state
222     (args, res)       = fres
223     (apps, conds, _)  = end_state
224
225 flattenExpr ::
226   BindMap
227   -> CoreExpr
228   -> FlattenState ([SignalDefMap], SignalUseMap)
229
230 flattenExpr binds lam@(Lam b expr) = do
231   -- Find the type of the binder
232   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
233   -- Create signal names for the binder
234   defs <- genSignalUses arg_ty
235   let binds' = (b, Left defs):binds
236   (args, res) <- flattenExpr binds' expr
237   return ((useMapToDefMap defs) : args, res)
238
239 flattenExpr binds (Var id) =
240   case bind of
241     Left sig_use -> return ([], sig_use)
242     Right _ -> error "Higher order functions not supported."
243   where
244     bind = Maybe.fromMaybe
245       (error $ "Argument " ++ Name.getOccString id ++ "is unknown")
246       (lookup id binds)
247
248 flattenExpr binds app@(App _ _) = do
249   -- Is this a data constructor application?
250   case CoreUtils.exprIsConApp_maybe app of
251     -- Is this a tuple construction?
252     Just (dc, args) -> if DataCon.isTupleCon dc 
253       then
254         flattenBuildTupleExpr binds (dataConAppArgs dc args)
255       else
256         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
257     otherwise ->
258       -- Normal function application
259       let ((Var f), args) = collectArgs app in
260       flattenApplicationExpr binds (CoreUtils.exprType app) f args
261   where
262     flattenBuildTupleExpr = error $ "Tuple construction not supported: " ++ (showSDoc $ ppr app)
263     -- | Flatten a normal application expression
264     flattenApplicationExpr binds ty f args = do
265       -- Find the function to call
266       let func = appToHsFunction ty f args
267       -- Flatten each of our args
268       flat_args <- (State.mapM (flattenExpr binds) args)
269       -- Check and split each of the arguments
270       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
271       -- Generate signals for our result
272       res <- genSignalUses ty
273       -- Create the function application
274       let app = FApp {
275         appFunc = func,
276         appArgs = arg_ress,
277         appRes  = useMapToDefMap res
278       }
279       addApp app
280       return ([], res)
281     -- | Check a flattened expression to see if it is valid to use as a
282     --   function argument. The first argument is the original expression for
283     --   use in the error message.
284     checkArg arg flat =
285       let (args, res) = flat in
286       if not (null args)
287         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
288         else flat 
289
290 flattenExpr _ _ = do
291   return ([], Tuple [])
292
293 appToHsFunction ::
294   Type.Type       -- ^ The return type
295   -> Var.Var      -- ^ The function to call
296   -> [CoreExpr]   -- ^ The function arguments
297   -> HsFunction   -- ^ The needed HsFunction
298
299 appToHsFunction ty f args =
300   HsFunction hsname hsargs hsres
301   where
302     hsname = Name.getOccString f
303     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
304     hsres  = useAsPort (mkHsValueMap ty)
305
306 -- vim: set ts=8 sw=2 sts=2 expandtab: