Add name hints to various signals generated.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified CoreUtils
12 import qualified TysWiredIn
13 import qualified IdInfo
14 import qualified Data.Traversable as Traversable
15 import qualified Data.Foldable as Foldable
16 import Control.Applicative
17 import Outputable ( showSDoc, ppr )
18 import qualified Control.Monad.State as State
19
20 import HsValueMap
21 import TranslatorTypes
22 import FlattenTypes
23
24 -- Extract the arguments from a data constructor application (that is, the
25 -- normal args, leaving out the type args).
26 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
27 dataConAppArgs dc args =
28     drop tycount args
29   where
30     tycount = length $ DataCon.dataConAllTyVars dc
31
32 genSignals ::
33   Type.Type
34   -> FlattenState SignalMap
35
36 genSignals ty =
37   -- First generate a map with the right structure containing the types, and
38   -- generate signals for each of them.
39   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
40
41 -- | Marks a signal as the given SigUse, if its id is in the list of id's
42 --   given.
43 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
44 markSignals use ids (id, info) =
45   (id, info')
46   where
47     info' = if id `elem` ids then info { sigUse = use} else info
48
49 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
50 markSignal use id = markSignals use [id]
51
52 -- | Flatten a haskell function
53 flattenFunction ::
54   HsFunction                      -- ^ The function to flatten
55   -> CoreBind                     -- ^ The function value
56   -> FlatFunction                 -- ^ The resulting flat function
57
58 flattenFunction _ (Rec _) = error "Recursive binders not supported"
59 flattenFunction hsfunc bind@(NonRec var expr) =
60   FlatFunction args res defs sigs
61   where
62     init_state        = ([], [], 0)
63     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
64     (defs, sigs, _)   = end_state
65     (args, res)       = fres
66
67 flattenTopExpr ::
68   HsFunction
69   -> CoreExpr
70   -> FlattenState ([SignalMap], SignalMap)
71
72 flattenTopExpr hsfunc expr = do
73   -- Flatten the expression
74   (args, res) <- flattenExpr [] expr
75   
76   -- Join the signal ids and uses together
77   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
78   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
79   -- Set the signal uses for each argument / result, possibly updating
80   -- argument or result signals.
81   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
82   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
83   return (args', res')
84   where
85     args_use Port = SigPortIn
86     args_use (State n) = SigStateOld n
87     res_use Port = SigPortOut
88     res_use (State n) = SigStateNew n
89
90
91 hsUseToSigUse :: 
92   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
93   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
94   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
95                               --   same as the input, but it could be different.
96 hsUseToSigUse f (id, use) = do
97   info <- getSignalInfo id
98   id' <- case sigUse info of 
99     -- Internal signals can be marked as different uses freely.
100     SigInternal -> do
101       return id
102     -- Signals that already have another use, must be duplicated before
103     -- marking. This prevents signals mapping to the same input or output
104     -- port or state variables and ports overlapping, etc.
105     otherwise -> do
106       duplicateSignal id
107   setSignalInfo id' (info { sigUse = f use})
108   return id'
109
110 -- | Creates a new internal signal with the same type as the given signal
111 copySignal :: SignalId -> FlattenState SignalId
112 copySignal id = do
113   -- Find the type of the original signal
114   info <- getSignalInfo id
115   let ty = sigTy info
116   -- Generate a new signal (which is SigInternal for now, that will be
117   -- sorted out later on).
118   genSignalId SigInternal ty
119
120 -- | Duplicate the given signal, assigning its value to the new signal.
121 --   Returns the new signal id.
122 duplicateSignal :: SignalId -> FlattenState SignalId
123 duplicateSignal id = do
124   -- Create a new signal
125   id' <- copySignal id
126   -- Assign the old signal to the new signal
127   addDef $ UncondDef (Left id) id'
128   -- Replace the signal with the new signal
129   return id'
130         
131 flattenExpr ::
132   BindMap
133   -> CoreExpr
134   -> FlattenState ([SignalMap], SignalMap)
135
136 flattenExpr binds lam@(Lam b expr) = do
137   -- Find the type of the binder
138   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
139   -- Create signal names for the binder
140   defs <- genSignals arg_ty
141   -- Add name hints to the generated signals
142   let binder_name = Name.getOccString b
143   Traversable.mapM (addNameHint binder_name) defs
144   let binds' = (b, Left defs):binds
145   (args, res) <- flattenExpr binds' expr
146   return (defs : args, res)
147
148 flattenExpr binds var@(Var id) =
149   case Var.globalIdVarDetails id of
150     IdInfo.NotGlobalId ->
151       let 
152         bind = Maybe.fromMaybe
153           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
154           (lookup id binds) 
155       in
156         case bind of
157           Left sig_use -> return ([], sig_use)
158           Right _ -> error "Higher order functions not supported."
159     IdInfo.DataConWorkId datacon -> do
160       lit <- dataConToLiteral datacon
161       let ty = CoreUtils.exprType var
162       sig_id <- genSignalId SigInternal ty
163       -- Add a name hint to the signal
164       addNameHint (Name.getOccString id) sig_id
165       addDef (UncondDef (Right $ Literal lit) sig_id)
166       return ([], Single sig_id)
167     otherwise ->
168       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
169
170 flattenExpr binds app@(App _ _) = do
171   -- Is this a data constructor application?
172   case CoreUtils.exprIsConApp_maybe app of
173     -- Is this a tuple construction?
174     Just (dc, args) -> if DataCon.isTupleCon dc 
175       then
176         flattenBuildTupleExpr binds (dataConAppArgs dc args)
177       else
178         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
179     otherwise ->
180       -- Normal function application
181       let ((Var f), args) = collectArgs app in
182       let fname = Name.getOccString f in
183       if fname == "fst" || fname == "snd" then do
184         (args', Tuple [a, b]) <- flattenExpr binds (last args)
185         return (args', if fname == "fst" then a else b)
186       else if fname == "patError" then do
187         -- This is essentially don't care, since the program will error out
188         -- here. We'll just define undriven signals here.
189         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
190         args <- mapM genSignals argtys
191         res <- genSignals resty
192         mapM (Traversable.mapM (addNameHint "NC")) args
193         Traversable.mapM (addNameHint "NC") res
194         return (args, res)
195       else if fname == "==" then do
196         -- Flatten the last two arguments (this skips the type arguments)
197         ([], a) <- flattenExpr binds (last $ init args)
198         ([], b) <- flattenExpr binds (last args)
199         res <- mkEqComparisons a b
200         return ([], res)
201       else
202         flattenApplicationExpr binds (CoreUtils.exprType app) f args
203   where
204     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
205     mkEqComparisons a b = do
206       let zipped = zipValueMaps a b
207       Traversable.mapM mkEqComparison zipped
208
209     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
210     mkEqComparison (a, b) = do
211       -- Generate a signal to hold our result
212       res <- genSignalId SigInternal TysWiredIn.boolTy
213       -- Add a name hint to the signal
214       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
215       addDef (UncondDef (Right $ Eq a b) res)
216       return res
217
218     flattenBuildTupleExpr binds args = do
219       -- Flatten each of our args
220       flat_args <- (State.mapM (flattenExpr binds) args)
221       -- Check and split each of the arguments
222       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
223       let res = Tuple arg_ress
224       return ([], res)
225
226     -- | Flatten a normal application expression
227     flattenApplicationExpr binds ty f args = do
228       -- Find the function to call
229       let func = appToHsFunction ty f args
230       -- Flatten each of our args
231       flat_args <- (State.mapM (flattenExpr binds) args)
232       -- Check and split each of the arguments
233       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
234       -- Generate signals for our result
235       res <- genSignals ty
236       -- Add name hints to the generated signals
237       let resname = Name.getOccString f ++ "_res"
238       Traversable.mapM (addNameHint resname) res
239       -- Create the function application
240       let app = FApp {
241         appFunc = func,
242         appArgs = arg_ress,
243         appRes  = res
244       }
245       addDef app
246       return ([], res)
247     -- | Check a flattened expression to see if it is valid to use as a
248     --   function argument. The first argument is the original expression for
249     --   use in the error message.
250     checkArg arg flat =
251       let (args, res) = flat in
252       if not (null args)
253         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
254         else flat 
255
256 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
257   (b_args, b_res) <- flattenExpr binds bexpr
258   if not (null b_args)
259     then
260       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
261     else do
262       let binds' = (b, Left b_res) : binds
263       -- Add name hints to the generated signals
264       let binder_name = Name.getOccString b
265       Traversable.mapM (addNameHint binder_name) b_res
266       flattenExpr binds' expr
267
268 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
269
270 flattenExpr binds expr@(Case scrut b _ alts) = do
271   -- TODO: Special casing for higher order functions
272   -- Flatten the scrutinee
273   (_, res) <- flattenExpr binds scrut
274   case alts of
275     -- TODO include b in the binds list
276     [alt] -> flattenSingleAltCaseExpr binds res b alt
277     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
278     otherwise -> flattenMultipleAltCaseExpr binds res b (reverse alts)
279   where
280     flattenSingleAltCaseExpr ::
281       BindMap
282                                 -- A list of bindings in effect
283       -> SignalMap              -- The scrutinee
284       -> CoreBndr               -- The binder to bind the scrutinee to
285       -> CoreAlt                -- The single alternative
286       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
287
288     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
289       if DataCon.isTupleCon datacon
290         then do
291           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
292           -- the existing bindings list and get the portname map for each of
293           -- it's elements.
294           let Tuple tuple_sigs = scrut
295           -- Add name hints to the returned signals
296           let binder_name = Name.getOccString b
297           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
298           -- Merge our existing binds with the new binds.
299           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
300           -- Expand the expression with the new binds list
301           flattenExpr binds' expr
302         else
303           if null bind_vars
304             then
305               -- DataAlts without arguments don't need processing
306               -- (flattenMultipleAltCaseExpr will have done this already).
307               flattenExpr binds expr
308             else
309               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
310
311     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
312       flattenExpr binds expr
313       
314     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
315
316     flattenMultipleAltCaseExpr ::
317       BindMap
318                                 -- A list of bindings in effect
319       -> SignalMap              -- The scrutinee
320       -> CoreBndr               -- The binder to bind the scrutinee to
321       -> [CoreAlt]              -- The alternatives
322       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
323
324     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
325       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
326       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
327       case a of
328         (DataAlt datacon, bind_vars, expr) -> do
329           if isDontCare datacon 
330             then do
331               -- Completely skip the dontcare cases
332               return (args', res')
333             else do
334               lit <- dataConToLiteral datacon
335               -- The scrutinee must be a single signal
336               let Single sig = scrut
337               -- Create a signal that contains a boolean
338               boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
339               addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
340               let expr = EqLit sig lit
341               addDef (UncondDef (Right expr) boolsigid)
342               -- Create conditional assignments of either args/res or
343               -- args'/res based on boolsigid, and return the result.
344               -- TODO: It seems this adds the name hint twice?
345               our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
346               our_res  <- mkConditionals boolsigid res res'
347               return (our_args, our_res)
348         otherwise ->
349           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
350       where
351         -- Select either the first or second signal map depending on the value
352         -- of the first argument (True == first map, False == second map)
353         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
354         mkConditionals boolsigid true false = do
355           let zipped = zipValueMaps true false
356           Traversable.mapM (mkConditional boolsigid) zipped
357
358         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
359         mkConditional boolsigid (true, false) = do
360           -- Create a new signal (true and false should be identically typed,
361           -- so it doesn't matter which one we copy).
362           res <- copySignal true
363           addDef (CondDef boolsigid true false res)
364           return res
365
366     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
367       flattenSingleAltCaseExpr binds scrut b a
368
369 flattenExpr _ expr = do
370   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
371
372 -- | Is the given data constructor a dontcare?
373 isDontCare :: DataCon.DataCon -> Bool
374 isDontCare datacon =
375   case Name.getOccString tyname of
376     -- TODO: Do something more robust than string matching
377     "Bit" ->
378       Name.getOccString dcname  == "DontCare"
379     otherwise ->
380       False
381   where
382     tycon = DataCon.dataConTyCon datacon
383     tyname = TyCon.tyConName tycon
384     dcname = DataCon.dataConName datacon
385
386 -- | Translates a dataconstructor without arguments to the corresponding
387 --   literal.
388 dataConToLiteral :: DataCon.DataCon -> FlattenState String
389 dataConToLiteral datacon = do
390   let tycon = DataCon.dataConTyCon datacon
391   let tyname = TyCon.tyConName tycon
392   case Name.getOccString tyname of
393     -- TODO: Do something more robust than string matching
394     "Bit"      -> do
395       let dcname = DataCon.dataConName datacon
396       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"; "DontCare" -> "'-'"
397       return lit
398     "Bool" -> do
399       let dcname = DataCon.dataConName datacon
400       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
401       return lit
402     otherwise ->
403       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
404
405 appToHsFunction ::
406   Type.Type       -- ^ The return type
407   -> Var.Var      -- ^ The function to call
408   -> [CoreExpr]   -- ^ The function arguments
409   -> HsFunction   -- ^ The needed HsFunction
410
411 appToHsFunction ty f args =
412   HsFunction hsname hsargs hsres
413   where
414     hsname = Name.getOccString f
415     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
416     hsres  = useAsPort (mkHsValueMap ty)
417
418 -- | Filters non-state signals and returns the state number and signal id for
419 --   state values.
420 filterState ::
421   SignalId                       -- | The signal id to look at
422   -> HsValueUse                  -- | How is this signal used?
423   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
424                                  --   signal was used as state
425
426 filterState id (State num) = 
427   Just (num, id)
428 filterState _ _ = Nothing
429
430 -- | Returns a list of the state number and signal id of all used-as-state
431 --   signals in the given maps.
432 stateList ::
433   HsUseMap
434   -> (SignalMap)
435   -> [(StateId, SignalId)]
436
437 stateList uses signals =
438     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
439   
440 -- | Returns pairs of signals that should be mapped to state in this function.
441 getOwnStates ::
442   HsFunction                      -- | The function to look at
443   -> FlatFunction                 -- | The function to look at
444   -> [(StateId, SignalInfo, SignalInfo)]   
445         -- | The state signals. The first is the state number, the second the
446         --   signal to assign the current state to, the last is the signal
447         --   that holds the new state.
448
449 getOwnStates hsfunc flatfunc =
450   [(old_num, old_info, new_info) 
451     | (old_num, old_info) <- args_states
452     , (new_num, new_info) <- res_states
453     , old_num == new_num]
454   where
455     sigs = flat_sigs flatfunc
456     -- Translate args and res to lists of (statenum, sigid)
457     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
458     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
459     -- Replace the second tuple element with the corresponding SignalInfo
460     args_states = map (Arrow.second $ signalInfo sigs) args
461     res_states = map (Arrow.second $ signalInfo sigs) res
462
463     
464 -- vim: set ts=8 sw=2 sts=2 expandtab: