Replace fst with a lambda in a map.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified Literal
12 import qualified CoreUtils
13 import qualified TysWiredIn
14 import qualified IdInfo
15 import qualified Data.Traversable as Traversable
16 import qualified Data.Foldable as Foldable
17 import Control.Applicative
18 import Outputable ( showSDoc, ppr )
19 import qualified Control.Monad.Trans.State as State
20
21 import HsValueMap
22 import TranslatorTypes
23 import FlattenTypes
24 import CoreTools
25
26 -- Extract the arguments from a data constructor application (that is, the
27 -- normal args, leaving out the type args).
28 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
29 dataConAppArgs dc args =
30     drop tycount args
31   where
32     tycount = length $ DataCon.dataConAllTyVars dc
33
34 genSignals ::
35   Type.Type
36   -> FlattenState SignalMap
37
38 genSignals ty =
39   -- First generate a map with the right structure containing the types, and
40   -- generate signals for each of them.
41   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
42
43 -- | Marks a signal as the given SigUse, if its id is in the list of id's
44 --   given.
45 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
46 markSignals use ids (id, info) =
47   (id, info')
48   where
49     info' = if id `elem` ids then info { sigUse = use} else info
50
51 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
52 markSignal use id = markSignals use [id]
53
54 -- | Flatten a haskell function
55 flattenFunction ::
56   HsFunction                      -- ^ The function to flatten
57   -> (CoreBndr, CoreExpr)         -- ^ The function value
58   -> FlatFunction                 -- ^ The resulting flat function
59
60 flattenFunction hsfunc (var, expr) =
61   FlatFunction args res defs sigs
62   where
63     init_state        = ([], [], 0)
64     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
65     (defs, sigs, _)   = end_state
66     (args, res)       = fres
67
68 flattenTopExpr ::
69   HsFunction
70   -> CoreExpr
71   -> FlattenState ([SignalMap], SignalMap)
72
73 flattenTopExpr hsfunc expr = do
74   -- Flatten the expression
75   (args, res) <- flattenExpr [] expr
76   
77   -- Join the signal ids and uses together
78   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
79   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
80   -- Set the signal uses for each argument / result, possibly updating
81   -- argument or result signals.
82   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
83   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
84   return (args', res')
85   where
86     args_use Port = SigPortIn
87     args_use (State n) = SigStateOld n
88     res_use Port = SigPortOut
89     res_use (State n) = SigStateNew n
90
91
92 hsUseToSigUse :: 
93   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
94   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
95   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
96                               --   same as the input, but it could be different.
97 hsUseToSigUse f (id, use) = do
98   info <- getSignalInfo id
99   id' <- case sigUse info of 
100     -- Internal signals can be marked as different uses freely.
101     SigInternal -> do
102       return id
103     -- Signals that already have another use, must be duplicated before
104     -- marking. This prevents signals mapping to the same input or output
105     -- port or state variables and ports overlapping, etc.
106     otherwise -> do
107       duplicateSignal id
108   setSignalInfo id' (info { sigUse = f use})
109   return id'
110
111 -- | Creates a new internal signal with the same type as the given signal
112 copySignal :: SignalId -> FlattenState SignalId
113 copySignal id = do
114   -- Find the type of the original signal
115   info <- getSignalInfo id
116   let ty = sigTy info
117   -- Generate a new signal (which is SigInternal for now, that will be
118   -- sorted out later on).
119   genSignalId SigInternal ty
120
121 -- | Duplicate the given signal, assigning its value to the new signal.
122 --   Returns the new signal id.
123 duplicateSignal :: SignalId -> FlattenState SignalId
124 duplicateSignal id = do
125   -- Create a new signal
126   id' <- copySignal id
127   -- Assign the old signal to the new signal
128   addDef $ UncondDef (Left id) id'
129   -- Replace the signal with the new signal
130   return id'
131         
132 flattenExpr ::
133   BindMap
134   -> CoreExpr
135   -> FlattenState ([SignalMap], SignalMap)
136
137 flattenExpr binds lam@(Lam b expr) = do
138   -- Find the type of the binder
139   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
140   -- Create signal names for the binder
141   defs <- genSignals arg_ty
142   -- Add name hints to the generated signals
143   let binder_name = Name.getOccString b
144   Traversable.mapM (addNameHint binder_name) defs
145   let binds' = (b, Left defs):binds
146   (args, res) <- flattenExpr binds' expr
147   return (defs : args, res)
148
149 flattenExpr binds var@(Var id) =
150   case Var.globalIdVarDetails id of
151     IdInfo.NotGlobalId ->
152       let 
153         bind = Maybe.fromMaybe
154           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
155           (lookup id binds) 
156       in
157         case bind of
158           Left sig_use -> return ([], sig_use)
159           Right _ -> error "Higher order functions not supported."
160     IdInfo.DataConWorkId datacon -> do
161       if DataCon.isTupleCon datacon && (null $ DataCon.dataConAllTyVars datacon)
162         then do
163           -- Empty tuple construction
164           return ([], Tuple [])
165         else do
166           lit <- dataConToLiteral datacon
167           let ty = CoreUtils.exprType var
168           sig_id <- genSignalId SigInternal ty
169           -- Add a name hint to the signal
170           addNameHint (Name.getOccString id) sig_id
171           addDef (UncondDef (Right $ Literal lit Nothing) sig_id)
172           return ([], Single sig_id)
173     IdInfo.VanillaGlobal ->
174       -- Treat references to globals as an application with zero elements
175       flattenApplicationExpr binds (CoreUtils.exprType var) id []
176     otherwise ->
177       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
178
179 flattenExpr binds app@(App _ _) = do
180   -- Is this a data constructor application?
181   case CoreUtils.exprIsConApp_maybe app of
182     -- Is this a tuple construction?
183     Just (dc, args) -> if DataCon.isTupleCon dc 
184       then
185         flattenBuildTupleExpr binds (dataConAppArgs dc args)
186       else
187         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
188     otherwise ->
189       -- Normal function application
190       let ((Var f), args) = collectArgs app in
191       let fname = Name.getOccString f in
192       if fname == "fst" || fname == "snd" then do
193         (args', Tuple [a, b]) <- flattenExpr binds (last args)
194         return (args', if fname == "fst" then a else b)
195       else if fname == "patError" then do
196         -- This is essentially don't care, since the program will error out
197         -- here. We'll just define undriven signals here.
198         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
199         args <- mapM genSignals argtys
200         res <- genSignals resty
201         mapM (Traversable.mapM (addNameHint "NC")) args
202         Traversable.mapM (addNameHint "NC") res
203         return (args, res)
204       else if fname == "==" then do
205         -- Flatten the last two arguments (this skips the type arguments)
206         ([], a) <- flattenExpr binds (last $ init args)
207         ([], b) <- flattenExpr binds (last args)
208         res <- mkEqComparisons a b
209         return ([], res)
210       else if fname == "fromInteger" then do
211         let [to_ty, to_dict, val] = args 
212         -- We assume this is an application of the GHC.Integer.smallInteger
213         -- function to a literal
214         let App smallint (Lit lit) = val
215         let (Literal.MachInt int) = lit
216         let ty = CoreUtils.exprType app
217         sig_id <- genSignalId SigInternal ty
218         -- TODO: fromInteger is defined for more types than just SizedWord
219         let len = sized_word_len ty
220         -- Use a to_unsigned to translate the number (a natural) to an unsiged
221         -- (array of bits)
222         let lit_str = "to_unsigned(" ++ (show int) ++ ", " ++ (show len) ++ ")"
223         -- Set the signal to our literal unconditionally, but add the type so
224         -- the literal will be typecast to the proper type.
225         addDef $ UncondDef (Right $ Literal lit_str (Just ty)) sig_id
226         return ([], Single sig_id)
227       else
228         flattenApplicationExpr binds (CoreUtils.exprType app) f args
229   where
230     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
231     mkEqComparisons a b = do
232       let zipped = zipValueMaps a b
233       Traversable.mapM mkEqComparison zipped
234
235     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
236     mkEqComparison (a, b) = do
237       -- Generate a signal to hold our result
238       res <- genSignalId SigInternal TysWiredIn.boolTy
239       -- Add a name hint to the signal
240       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
241       addDef (UncondDef (Right $ Eq a b) res)
242       return res
243
244     flattenBuildTupleExpr binds args = do
245       -- Flatten each of our args
246       flat_args <- (mapM (flattenExpr binds) args)
247       -- Check and split each of the arguments
248       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
249       let res = Tuple arg_ress
250       return ([], res)
251
252 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
253   (b_args, b_res) <- flattenExpr binds bexpr
254   if not (null b_args)
255     then
256       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
257     else do
258       let binds' = (b, Left b_res) : binds
259       -- Add name hints to the generated signals
260       let binder_name = Name.getOccString b
261       Traversable.mapM (addNameHint binder_name) b_res
262       flattenExpr binds' expr
263
264 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
265
266 flattenExpr binds expr@(Case scrut b _ alts) = do
267   -- TODO: Special casing for higher order functions
268   -- Flatten the scrutinee
269   (_, res) <- flattenExpr binds scrut
270   -- Put the scrutinee in the BindMap
271   let binds' = (b, Left res) : binds
272   case alts of
273     [alt] -> flattenSingleAltCaseExpr binds' res b alt
274     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
275     otherwise -> flattenMultipleAltCaseExpr binds' res b (reverse alts)
276   where
277     flattenSingleAltCaseExpr ::
278       BindMap
279                                 -- A list of bindings in effect
280       -> SignalMap              -- The scrutinee
281       -> CoreBndr               -- The binder to bind the scrutinee to
282       -> CoreAlt                -- The single alternative
283       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
284
285     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
286       if DataCon.isTupleCon datacon
287         then do
288           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
289           -- the existing bindings list and get the portname map for each of
290           -- it's elements.
291           let Tuple tuple_sigs = scrut
292           -- Add name hints to the returned signals
293           let binder_name = Name.getOccString b
294           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
295           -- Merge our existing binds with the new binds.
296           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
297           -- Expand the expression with the new binds list
298           flattenExpr binds' expr
299         else
300           if null bind_vars
301             then
302               -- DataAlts without arguments don't need processing
303               -- (flattenMultipleAltCaseExpr will have done this already).
304               flattenExpr binds expr
305             else
306               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
307
308     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
309       flattenExpr binds expr
310       
311     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
312
313     flattenMultipleAltCaseExpr ::
314       BindMap
315                                 -- A list of bindings in effect
316       -> SignalMap              -- The scrutinee
317       -> CoreBndr               -- The binder to bind the scrutinee to
318       -> [CoreAlt]              -- The alternatives
319       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
320
321     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
322       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
323       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
324       case a of
325         (DataAlt datacon, bind_vars, expr) -> do
326           lit <- dataConToLiteral datacon
327           -- The scrutinee must be a single signal
328           let Single sig = scrut
329           -- Create a signal that contains a boolean
330           boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
331           addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
332           let expr = EqLit sig lit
333           addDef (UncondDef (Right expr) boolsigid)
334           -- Create conditional assignments of either args/res or
335           -- args'/res based on boolsigid, and return the result.
336           -- TODO: It seems this adds the name hint twice?
337           our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
338           our_res  <- mkConditionals boolsigid res res'
339           return (our_args, our_res)
340         otherwise ->
341           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
342       where
343         -- Select either the first or second signal map depending on the value
344         -- of the first argument (True == first map, False == second map)
345         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
346         mkConditionals boolsigid true false = do
347           let zipped = zipValueMaps true false
348           Traversable.mapM (mkConditional boolsigid) zipped
349
350         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
351         mkConditional boolsigid (true, false) = do
352           -- Create a new signal (true and false should be identically typed,
353           -- so it doesn't matter which one we copy).
354           res <- copySignal true
355           addDef (CondDef boolsigid true false res)
356           return res
357
358     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
359       flattenSingleAltCaseExpr binds scrut b a
360
361 flattenExpr _ expr = do
362   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
363
364 -- | Flatten a normal application expression
365 flattenApplicationExpr binds ty f args = do
366   -- Find the function to call
367   let func = appToHsFunction ty f args
368   -- Flatten each of our args
369   flat_args <- (mapM (flattenExpr binds) args)
370   -- Check and split each of the arguments
371   let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
372   -- Generate signals for our result
373   res <- genSignals ty
374   -- Add name hints to the generated signals
375   let resname = Name.getOccString f ++ "_res"
376   Traversable.mapM (addNameHint resname) res
377   -- Create the function application
378   let app = FApp {
379     appFunc = func,
380     appArgs = arg_ress,
381     appRes  = res
382   }
383   addDef app
384   return ([], res)
385 -- | Check a flattened expression to see if it is valid to use as a
386 --   function argument. The first argument is the original expression for
387 --   use in the error message.
388 checkArg arg flat =
389   let (args, res) = flat in
390   if not (null args)
391     then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
392     else flat 
393
394 -- | Translates a dataconstructor without arguments to the corresponding
395 --   literal.
396 dataConToLiteral :: DataCon.DataCon -> FlattenState String
397 dataConToLiteral datacon = do
398   let tycon = DataCon.dataConTyCon datacon
399   let tyname = TyCon.tyConName tycon
400   case Name.getOccString tyname of
401     -- TODO: Do something more robust than string matching
402     "Bit"      -> do
403       let dcname = DataCon.dataConName datacon
404       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"
405       return lit
406     "Bool" -> do
407       let dcname = DataCon.dataConName datacon
408       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
409       return lit
410     otherwise ->
411       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
412
413 appToHsFunction ::
414   Type.Type       -- ^ The return type
415   -> Var.Var      -- ^ The function to call
416   -> [CoreExpr]   -- ^ The function arguments
417   -> HsFunction   -- ^ The needed HsFunction
418
419 appToHsFunction ty f args =
420   HsFunction hsname hsargs hsres
421   where
422     hsname = Name.getOccString f
423     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
424     hsres  = useAsPort (mkHsValueMap ty)
425
426 -- | Filters non-state signals and returns the state number and signal id for
427 --   state values.
428 filterState ::
429   SignalId                       -- | The signal id to look at
430   -> HsValueUse                  -- | How is this signal used?
431   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
432                                  --   signal was used as state
433
434 filterState id (State num) = 
435   Just (num, id)
436 filterState _ _ = Nothing
437
438 -- | Returns a list of the state number and signal id of all used-as-state
439 --   signals in the given maps.
440 stateList ::
441   HsUseMap
442   -> (SignalMap)
443   -> [(StateId, SignalId)]
444
445 stateList uses signals =
446     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
447   
448 -- vim: set ts=8 sw=2 sts=2 expandtab: