Use a different approach for marking SigUses.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import Control.Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified CoreUtils
11 import qualified Data.Traversable as Traversable
12 import qualified Data.Foldable as Foldable
13 import Control.Applicative
14 import Outputable ( showSDoc, ppr )
15 import qualified Control.Monad.State as State
16
17 import HsValueMap
18 import TranslatorTypes
19 import FlattenTypes
20
21 -- Extract the arguments from a data constructor application (that is, the
22 -- normal args, leaving out the type args).
23 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
24 dataConAppArgs dc args =
25     drop tycount args
26   where
27     tycount = length $ DataCon.dataConAllTyVars dc
28
29 genSignals ::
30   Type.Type
31   -> FlattenState SignalMap
32
33 genSignals ty =
34   -- First generate a map with the right structure containing the types, and
35   -- generate signals for each of them.
36   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
37
38 -- | Marks a signal as the given SigUse, if its id is in the list of id's
39 --   given.
40 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
41 markSignals use ids (id, info) =
42   (id, info')
43   where
44     info' = if id `elem` ids then info { sigUse = use} else info
45
46 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
47 markSignal use id = markSignals use [id]
48
49 -- | Flatten a haskell function
50 flattenFunction ::
51   HsFunction                      -- ^ The function to flatten
52   -> CoreBind                     -- ^ The function value
53   -> FlatFunction                 -- ^ The resulting flat function
54
55 flattenFunction _ (Rec _) = error "Recursive binders not supported"
56 flattenFunction hsfunc bind@(NonRec var expr) =
57   FlatFunction args res defs sigs
58   where
59     init_state        = ([], [], 0)
60     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
61     (defs, sigs, _)   = end_state
62     (args, res)       = fres
63
64 flattenTopExpr ::
65   HsFunction
66   -> CoreExpr
67   -> FlattenState ([SignalMap], SignalMap)
68
69 flattenTopExpr hsfunc expr = do
70   -- Flatten the expression
71   (args, res) <- flattenExpr [] expr
72   
73   -- Join the signal ids and uses together
74   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
75   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
76   -- Set the signal uses for each argument / result, possibly updating
77   -- argument or result signals.
78   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
79   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
80   return (args', res')
81   where
82     args_use Port = SigPortIn
83     args_use (State n) = SigStateOld n
84     res_use Port = SigPortOut
85     res_use (State n) = SigStateNew n
86
87
88 hsUseToSigUse :: 
89   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
90   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
91   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
92                               --   same as the input, but it could be different.
93 hsUseToSigUse f (id, use) = do
94   info <- getSignalInfo id
95   id' <- case sigUse info of 
96     -- Internal signals can be marked as different uses freely.
97     SigInternal -> do
98       return id
99     -- Signals that already have another use, must be duplicated before
100     -- marking. This prevents signals mapping to the same input or output
101     -- port or state variables and ports overlapping, etc.
102     otherwise -> do
103       duplicateSignal id
104   setSignalInfo id' (info { sigUse = f use})
105   return id'
106
107 -- | Duplicate the given signal, assigning its value to the new signal.
108 --   Returns the new signal id.
109 duplicateSignal :: SignalId -> FlattenState SignalId
110 duplicateSignal id = do
111   -- Find the type of the original signal
112   info <- getSignalInfo id
113   let ty = sigTy info
114   -- Generate a new signal (which is SigInternal for now, that will be
115   -- sorted out later on).
116   id' <- genSignalId SigInternal ty
117   -- Assign the old signal to the new signal
118   addDef $ UncondDef id id'
119   -- Replace the signal with the new signal
120   return id'
121         
122 flattenExpr ::
123   BindMap
124   -> CoreExpr
125   -> FlattenState ([SignalMap], SignalMap)
126
127 flattenExpr binds lam@(Lam b expr) = do
128   -- Find the type of the binder
129   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
130   -- Create signal names for the binder
131   defs <- genSignals arg_ty
132   let binds' = (b, Left defs):binds
133   (args, res) <- flattenExpr binds' expr
134   return (defs : args, res)
135
136 flattenExpr binds (Var id) =
137   case bind of
138     Left sig_use -> return ([], sig_use)
139     Right _ -> error "Higher order functions not supported."
140   where
141     bind = Maybe.fromMaybe
142       (error $ "Argument " ++ Name.getOccString id ++ "is unknown")
143       (lookup id binds)
144
145 flattenExpr binds app@(App _ _) = do
146   -- Is this a data constructor application?
147   case CoreUtils.exprIsConApp_maybe app of
148     -- Is this a tuple construction?
149     Just (dc, args) -> if DataCon.isTupleCon dc 
150       then
151         flattenBuildTupleExpr binds (dataConAppArgs dc args)
152       else
153         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
154     otherwise ->
155       -- Normal function application
156       let ((Var f), args) = collectArgs app in
157       flattenApplicationExpr binds (CoreUtils.exprType app) f args
158   where
159     flattenBuildTupleExpr binds args = do
160       -- Flatten each of our args
161       flat_args <- (State.mapM (flattenExpr binds) args)
162       -- Check and split each of the arguments
163       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
164       let res = Tuple arg_ress
165       return ([], res)
166
167     -- | Flatten a normal application expression
168     flattenApplicationExpr binds ty f args = do
169       -- Find the function to call
170       let func = appToHsFunction ty f args
171       -- Flatten each of our args
172       flat_args <- (State.mapM (flattenExpr binds) args)
173       -- Check and split each of the arguments
174       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
175       -- Generate signals for our result
176       res <- genSignals ty
177       -- Create the function application
178       let app = FApp {
179         appFunc = func,
180         appArgs = arg_ress,
181         appRes  = res
182       }
183       addDef app
184       return ([], res)
185     -- | Check a flattened expression to see if it is valid to use as a
186     --   function argument. The first argument is the original expression for
187     --   use in the error message.
188     checkArg arg flat =
189       let (args, res) = flat in
190       if not (null args)
191         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
192         else flat 
193
194 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
195   (b_args, b_res) <- flattenExpr binds bexpr
196   if not (null b_args)
197     then
198       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
199     else
200       let binds' = (b, Left b_res) : binds in
201       flattenExpr binds' expr
202
203 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
204
205 flattenExpr binds expr@(Case (Var v) b _ alts) =
206   case alts of
207     [alt] -> flattenSingleAltCaseExpr binds v b alt
208     otherwise -> error $ "Multiple alternative case expression not supported: " ++ (showSDoc $ ppr expr)
209   where
210     flattenSingleAltCaseExpr ::
211       BindMap
212                                 -- A list of bindings in effect
213       -> Var.Var                -- The scrutinee
214       -> CoreBndr               -- The binder to bind the scrutinee to
215       -> CoreAlt                -- The single alternative
216       -> FlattenState ( [SignalMap], SignalMap)
217                                            -- See expandExpr
218     flattenSingleAltCaseExpr binds v b alt@(DataAlt datacon, bind_vars, expr) =
219       if not (DataCon.isTupleCon datacon) 
220         then
221           error $ "Dataconstructors other than tuple constructors not supported in case pattern of alternative: " ++ (showSDoc $ ppr alt)
222         else
223           let
224             -- Lookup the scrutinee (which must be a variable bound to a tuple) in
225             -- the existing bindings list and get the portname map for each of
226             -- it's elements.
227             Left (Tuple tuple_sigs) = Maybe.fromMaybe 
228               (error $ "Case expression uses unknown scrutinee " ++ Name.getOccString v)
229               (lookup v binds)
230             -- TODO include b in the binds list
231             -- Merge our existing binds with the new binds.
232             binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
233           in
234             -- Expand the expression with the new binds list
235             flattenExpr binds' expr
236     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
237
238
239       
240 flattenExpr _ _ = do
241   return ([], Tuple [])
242
243 appToHsFunction ::
244   Type.Type       -- ^ The return type
245   -> Var.Var      -- ^ The function to call
246   -> [CoreExpr]   -- ^ The function arguments
247   -> HsFunction   -- ^ The needed HsFunction
248
249 appToHsFunction ty f args =
250   HsFunction hsname hsargs hsres
251   where
252     hsname = Name.getOccString f
253     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
254     hsres  = useAsPort (mkHsValueMap ty)
255
256 -- | Filters non-state signals and returns the state number and signal id for
257 --   state values.
258 filterState ::
259   SignalId                       -- | The signal id to look at
260   -> HsValueUse                  -- | How is this signal used?
261   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
262                                  --   signal was used as state
263
264 filterState id (State num) = 
265   Just (num, id)
266 filterState _ _ = Nothing
267
268 -- | Returns a list of the state number and signal id of all used-as-state
269 --   signals in the given maps.
270 stateList ::
271   HsUseMap
272   -> (SignalMap)
273   -> [(StateId, SignalId)]
274
275 stateList uses signals =
276     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
277   
278 -- | Returns pairs of signals that should be mapped to state in this function.
279 getOwnStates ::
280   HsFunction                      -- | The function to look at
281   -> FlatFunction                 -- | The function to look at
282   -> [(StateId, SignalInfo, SignalInfo)]   
283         -- | The state signals. The first is the state number, the second the
284         --   signal to assign the current state to, the last is the signal
285         --   that holds the new state.
286
287 getOwnStates hsfunc flatfunc =
288   [(old_num, old_info, new_info) 
289     | (old_num, old_info) <- args_states
290     , (new_num, new_info) <- res_states
291     , old_num == new_num]
292   where
293     sigs = flat_sigs flatfunc
294     -- Translate args and res to lists of (statenum, sigid)
295     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
296     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
297     -- Replace the second tuple element with the corresponding SignalInfo
298     args_states = map (Arrow.second $ signalInfo sigs) args
299     res_states = map (Arrow.second $ signalInfo sigs) res
300
301     
302 -- vim: set ts=8 sw=2 sts=2 expandtab: