Support binding the scrutinee of a Case expression.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified Literal
12 import qualified CoreUtils
13 import qualified TysWiredIn
14 import qualified IdInfo
15 import qualified Data.Traversable as Traversable
16 import qualified Data.Foldable as Foldable
17 import Control.Applicative
18 import Outputable ( showSDoc, ppr )
19 import qualified Control.Monad.Trans.State as State
20
21 import HsValueMap
22 import TranslatorTypes
23 import FlattenTypes
24 import CoreTools
25
26 -- Extract the arguments from a data constructor application (that is, the
27 -- normal args, leaving out the type args).
28 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
29 dataConAppArgs dc args =
30     drop tycount args
31   where
32     tycount = length $ DataCon.dataConAllTyVars dc
33
34 genSignals ::
35   Type.Type
36   -> FlattenState SignalMap
37
38 genSignals ty =
39   -- First generate a map with the right structure containing the types, and
40   -- generate signals for each of them.
41   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
42
43 -- | Marks a signal as the given SigUse, if its id is in the list of id's
44 --   given.
45 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
46 markSignals use ids (id, info) =
47   (id, info')
48   where
49     info' = if id `elem` ids then info { sigUse = use} else info
50
51 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
52 markSignal use id = markSignals use [id]
53
54 -- | Flatten a haskell function
55 flattenFunction ::
56   HsFunction                      -- ^ The function to flatten
57   -> CoreBind                     -- ^ The function value
58   -> FlatFunction                 -- ^ The resulting flat function
59
60 flattenFunction _ (Rec _) = error "Recursive binders not supported"
61 flattenFunction hsfunc bind@(NonRec var expr) =
62   FlatFunction args res defs sigs
63   where
64     init_state        = ([], [], 0)
65     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
66     (defs, sigs, _)   = end_state
67     (args, res)       = fres
68
69 flattenTopExpr ::
70   HsFunction
71   -> CoreExpr
72   -> FlattenState ([SignalMap], SignalMap)
73
74 flattenTopExpr hsfunc expr = do
75   -- Flatten the expression
76   (args, res) <- flattenExpr [] expr
77   
78   -- Join the signal ids and uses together
79   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
80   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
81   -- Set the signal uses for each argument / result, possibly updating
82   -- argument or result signals.
83   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
84   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
85   return (args', res')
86   where
87     args_use Port = SigPortIn
88     args_use (State n) = SigStateOld n
89     res_use Port = SigPortOut
90     res_use (State n) = SigStateNew n
91
92
93 hsUseToSigUse :: 
94   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
95   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
96   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
97                               --   same as the input, but it could be different.
98 hsUseToSigUse f (id, use) = do
99   info <- getSignalInfo id
100   id' <- case sigUse info of 
101     -- Internal signals can be marked as different uses freely.
102     SigInternal -> do
103       return id
104     -- Signals that already have another use, must be duplicated before
105     -- marking. This prevents signals mapping to the same input or output
106     -- port or state variables and ports overlapping, etc.
107     otherwise -> do
108       duplicateSignal id
109   setSignalInfo id' (info { sigUse = f use})
110   return id'
111
112 -- | Creates a new internal signal with the same type as the given signal
113 copySignal :: SignalId -> FlattenState SignalId
114 copySignal id = do
115   -- Find the type of the original signal
116   info <- getSignalInfo id
117   let ty = sigTy info
118   -- Generate a new signal (which is SigInternal for now, that will be
119   -- sorted out later on).
120   genSignalId SigInternal ty
121
122 -- | Duplicate the given signal, assigning its value to the new signal.
123 --   Returns the new signal id.
124 duplicateSignal :: SignalId -> FlattenState SignalId
125 duplicateSignal id = do
126   -- Create a new signal
127   id' <- copySignal id
128   -- Assign the old signal to the new signal
129   addDef $ UncondDef (Left id) id'
130   -- Replace the signal with the new signal
131   return id'
132         
133 flattenExpr ::
134   BindMap
135   -> CoreExpr
136   -> FlattenState ([SignalMap], SignalMap)
137
138 flattenExpr binds lam@(Lam b expr) = do
139   -- Find the type of the binder
140   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
141   -- Create signal names for the binder
142   defs <- genSignals arg_ty
143   -- Add name hints to the generated signals
144   let binder_name = Name.getOccString b
145   Traversable.mapM (addNameHint binder_name) defs
146   let binds' = (b, Left defs):binds
147   (args, res) <- flattenExpr binds' expr
148   return (defs : args, res)
149
150 flattenExpr binds var@(Var id) =
151   case Var.globalIdVarDetails id of
152     IdInfo.NotGlobalId ->
153       let 
154         bind = Maybe.fromMaybe
155           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
156           (lookup id binds) 
157       in
158         case bind of
159           Left sig_use -> return ([], sig_use)
160           Right _ -> error "Higher order functions not supported."
161     IdInfo.DataConWorkId datacon -> do
162       if DataCon.isTupleCon datacon && (null $ DataCon.dataConAllTyVars datacon)
163         then do
164           -- Empty tuple construction
165           return ([], Tuple [])
166         else do
167           lit <- dataConToLiteral datacon
168           let ty = CoreUtils.exprType var
169           sig_id <- genSignalId SigInternal ty
170           -- Add a name hint to the signal
171           addNameHint (Name.getOccString id) sig_id
172           addDef (UncondDef (Right $ Literal lit Nothing) sig_id)
173           return ([], Single sig_id)
174     IdInfo.VanillaGlobal ->
175       -- Treat references to globals as an application with zero elements
176       flattenApplicationExpr binds (CoreUtils.exprType var) id []
177     otherwise ->
178       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
179
180 flattenExpr binds app@(App _ _) = do
181   -- Is this a data constructor application?
182   case CoreUtils.exprIsConApp_maybe app of
183     -- Is this a tuple construction?
184     Just (dc, args) -> if DataCon.isTupleCon dc 
185       then
186         flattenBuildTupleExpr binds (dataConAppArgs dc args)
187       else
188         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
189     otherwise ->
190       -- Normal function application
191       let ((Var f), args) = collectArgs app in
192       let fname = Name.getOccString f in
193       if fname == "fst" || fname == "snd" then do
194         (args', Tuple [a, b]) <- flattenExpr binds (last args)
195         return (args', if fname == "fst" then a else b)
196       else if fname == "patError" then do
197         -- This is essentially don't care, since the program will error out
198         -- here. We'll just define undriven signals here.
199         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
200         args <- mapM genSignals argtys
201         res <- genSignals resty
202         mapM (Traversable.mapM (addNameHint "NC")) args
203         Traversable.mapM (addNameHint "NC") res
204         return (args, res)
205       else if fname == "==" then do
206         -- Flatten the last two arguments (this skips the type arguments)
207         ([], a) <- flattenExpr binds (last $ init args)
208         ([], b) <- flattenExpr binds (last args)
209         res <- mkEqComparisons a b
210         return ([], res)
211       else if fname == "fromInteger" then do
212         let [to_ty, to_dict, val] = args 
213         -- We assume this is an application of the GHC.Integer.smallInteger
214         -- function to a literal
215         let App smallint (Lit lit) = val
216         let (Literal.MachInt int) = lit
217         let ty = CoreUtils.exprType app
218         sig_id <- genSignalId SigInternal ty
219         -- TODO: fromInteger is defined for more types than just SizedWord
220         let len = sized_word_len ty
221         -- Use a to_unsigned to translate the number (a natural) to an unsiged
222         -- (array of bits)
223         let lit_str = "to_unsigned(" ++ (show int) ++ ", " ++ (show len) ++ ")"
224         -- Set the signal to our literal unconditionally, but add the type so
225         -- the literal will be typecast to the proper type.
226         addDef $ UncondDef (Right $ Literal lit_str (Just ty)) sig_id
227         return ([], Single sig_id)
228       else
229         flattenApplicationExpr binds (CoreUtils.exprType app) f args
230   where
231     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
232     mkEqComparisons a b = do
233       let zipped = zipValueMaps a b
234       Traversable.mapM mkEqComparison zipped
235
236     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
237     mkEqComparison (a, b) = do
238       -- Generate a signal to hold our result
239       res <- genSignalId SigInternal TysWiredIn.boolTy
240       -- Add a name hint to the signal
241       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
242       addDef (UncondDef (Right $ Eq a b) res)
243       return res
244
245     flattenBuildTupleExpr binds args = do
246       -- Flatten each of our args
247       flat_args <- (mapM (flattenExpr binds) args)
248       -- Check and split each of the arguments
249       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
250       let res = Tuple arg_ress
251       return ([], res)
252
253 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
254   (b_args, b_res) <- flattenExpr binds bexpr
255   if not (null b_args)
256     then
257       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
258     else do
259       let binds' = (b, Left b_res) : binds
260       -- Add name hints to the generated signals
261       let binder_name = Name.getOccString b
262       Traversable.mapM (addNameHint binder_name) b_res
263       flattenExpr binds' expr
264
265 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
266
267 flattenExpr binds expr@(Case scrut b _ alts) = do
268   -- TODO: Special casing for higher order functions
269   -- Flatten the scrutinee
270   (_, res) <- flattenExpr binds scrut
271   -- Put the scrutinee in the BindMap
272   let binds' = (b, Left res) : binds
273   case alts of
274     [alt] -> flattenSingleAltCaseExpr binds' res b alt
275     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
276     otherwise -> flattenMultipleAltCaseExpr binds' res b (reverse alts)
277   where
278     flattenSingleAltCaseExpr ::
279       BindMap
280                                 -- A list of bindings in effect
281       -> SignalMap              -- The scrutinee
282       -> CoreBndr               -- The binder to bind the scrutinee to
283       -> CoreAlt                -- The single alternative
284       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
285
286     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
287       if DataCon.isTupleCon datacon
288         then do
289           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
290           -- the existing bindings list and get the portname map for each of
291           -- it's elements.
292           let Tuple tuple_sigs = scrut
293           -- Add name hints to the returned signals
294           let binder_name = Name.getOccString b
295           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
296           -- Merge our existing binds with the new binds.
297           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
298           -- Expand the expression with the new binds list
299           flattenExpr binds' expr
300         else
301           if null bind_vars
302             then
303               -- DataAlts without arguments don't need processing
304               -- (flattenMultipleAltCaseExpr will have done this already).
305               flattenExpr binds expr
306             else
307               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
308
309     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
310       flattenExpr binds expr
311       
312     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
313
314     flattenMultipleAltCaseExpr ::
315       BindMap
316                                 -- A list of bindings in effect
317       -> SignalMap              -- The scrutinee
318       -> CoreBndr               -- The binder to bind the scrutinee to
319       -> [CoreAlt]              -- The alternatives
320       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
321
322     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
323       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
324       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
325       case a of
326         (DataAlt datacon, bind_vars, expr) -> do
327           lit <- dataConToLiteral datacon
328           -- The scrutinee must be a single signal
329           let Single sig = scrut
330           -- Create a signal that contains a boolean
331           boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
332           addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
333           let expr = EqLit sig lit
334           addDef (UncondDef (Right expr) boolsigid)
335           -- Create conditional assignments of either args/res or
336           -- args'/res based on boolsigid, and return the result.
337           -- TODO: It seems this adds the name hint twice?
338           our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
339           our_res  <- mkConditionals boolsigid res res'
340           return (our_args, our_res)
341         otherwise ->
342           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
343       where
344         -- Select either the first or second signal map depending on the value
345         -- of the first argument (True == first map, False == second map)
346         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
347         mkConditionals boolsigid true false = do
348           let zipped = zipValueMaps true false
349           Traversable.mapM (mkConditional boolsigid) zipped
350
351         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
352         mkConditional boolsigid (true, false) = do
353           -- Create a new signal (true and false should be identically typed,
354           -- so it doesn't matter which one we copy).
355           res <- copySignal true
356           addDef (CondDef boolsigid true false res)
357           return res
358
359     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
360       flattenSingleAltCaseExpr binds scrut b a
361
362 flattenExpr _ expr = do
363   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
364
365 -- | Flatten a normal application expression
366 flattenApplicationExpr binds ty f args = do
367   -- Find the function to call
368   let func = appToHsFunction ty f args
369   -- Flatten each of our args
370   flat_args <- (mapM (flattenExpr binds) args)
371   -- Check and split each of the arguments
372   let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
373   -- Generate signals for our result
374   res <- genSignals ty
375   -- Add name hints to the generated signals
376   let resname = Name.getOccString f ++ "_res"
377   Traversable.mapM (addNameHint resname) res
378   -- Create the function application
379   let app = FApp {
380     appFunc = func,
381     appArgs = arg_ress,
382     appRes  = res
383   }
384   addDef app
385   return ([], res)
386 -- | Check a flattened expression to see if it is valid to use as a
387 --   function argument. The first argument is the original expression for
388 --   use in the error message.
389 checkArg arg flat =
390   let (args, res) = flat in
391   if not (null args)
392     then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
393     else flat 
394
395 -- | Translates a dataconstructor without arguments to the corresponding
396 --   literal.
397 dataConToLiteral :: DataCon.DataCon -> FlattenState String
398 dataConToLiteral datacon = do
399   let tycon = DataCon.dataConTyCon datacon
400   let tyname = TyCon.tyConName tycon
401   case Name.getOccString tyname of
402     -- TODO: Do something more robust than string matching
403     "Bit"      -> do
404       let dcname = DataCon.dataConName datacon
405       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"
406       return lit
407     "Bool" -> do
408       let dcname = DataCon.dataConName datacon
409       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
410       return lit
411     otherwise ->
412       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
413
414 appToHsFunction ::
415   Type.Type       -- ^ The return type
416   -> Var.Var      -- ^ The function to call
417   -> [CoreExpr]   -- ^ The function arguments
418   -> HsFunction   -- ^ The needed HsFunction
419
420 appToHsFunction ty f args =
421   HsFunction hsname hsargs hsres
422   where
423     hsname = Name.getOccString f
424     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
425     hsres  = useAsPort (mkHsValueMap ty)
426
427 -- | Filters non-state signals and returns the state number and signal id for
428 --   state values.
429 filterState ::
430   SignalId                       -- | The signal id to look at
431   -> HsValueUse                  -- | How is this signal used?
432   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
433                                  --   signal was used as state
434
435 filterState id (State num) = 
436   Just (num, id)
437 filterState _ _ = Nothing
438
439 -- | Returns a list of the state number and signal id of all used-as-state
440 --   signals in the given maps.
441 stateList ::
442   HsUseMap
443   -> (SignalMap)
444   -> [(StateId, SignalId)]
445
446 stateList uses signals =
447     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
448   
449 -- vim: set ts=8 sw=2 sts=2 expandtab: