Remove the now obsolete getOwnStates.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified CoreUtils
12 import qualified TysWiredIn
13 import qualified IdInfo
14 import qualified Data.Traversable as Traversable
15 import qualified Data.Foldable as Foldable
16 import Control.Applicative
17 import Outputable ( showSDoc, ppr )
18 import qualified Control.Monad.State as State
19
20 import HsValueMap
21 import TranslatorTypes
22 import FlattenTypes
23
24 -- Extract the arguments from a data constructor application (that is, the
25 -- normal args, leaving out the type args).
26 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
27 dataConAppArgs dc args =
28     drop tycount args
29   where
30     tycount = length $ DataCon.dataConAllTyVars dc
31
32 genSignals ::
33   Type.Type
34   -> FlattenState SignalMap
35
36 genSignals ty =
37   -- First generate a map with the right structure containing the types, and
38   -- generate signals for each of them.
39   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
40
41 -- | Marks a signal as the given SigUse, if its id is in the list of id's
42 --   given.
43 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
44 markSignals use ids (id, info) =
45   (id, info')
46   where
47     info' = if id `elem` ids then info { sigUse = use} else info
48
49 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
50 markSignal use id = markSignals use [id]
51
52 -- | Flatten a haskell function
53 flattenFunction ::
54   HsFunction                      -- ^ The function to flatten
55   -> CoreBind                     -- ^ The function value
56   -> FlatFunction                 -- ^ The resulting flat function
57
58 flattenFunction _ (Rec _) = error "Recursive binders not supported"
59 flattenFunction hsfunc bind@(NonRec var expr) =
60   FlatFunction args res defs sigs
61   where
62     init_state        = ([], [], 0)
63     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
64     (defs, sigs, _)   = end_state
65     (args, res)       = fres
66
67 flattenTopExpr ::
68   HsFunction
69   -> CoreExpr
70   -> FlattenState ([SignalMap], SignalMap)
71
72 flattenTopExpr hsfunc expr = do
73   -- Flatten the expression
74   (args, res) <- flattenExpr [] expr
75   
76   -- Join the signal ids and uses together
77   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
78   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
79   -- Set the signal uses for each argument / result, possibly updating
80   -- argument or result signals.
81   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
82   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
83   return (args', res')
84   where
85     args_use Port = SigPortIn
86     args_use (State n) = SigStateOld n
87     res_use Port = SigPortOut
88     res_use (State n) = SigStateNew n
89
90
91 hsUseToSigUse :: 
92   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
93   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
94   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
95                               --   same as the input, but it could be different.
96 hsUseToSigUse f (id, use) = do
97   info <- getSignalInfo id
98   id' <- case sigUse info of 
99     -- Internal signals can be marked as different uses freely.
100     SigInternal -> do
101       return id
102     -- Signals that already have another use, must be duplicated before
103     -- marking. This prevents signals mapping to the same input or output
104     -- port or state variables and ports overlapping, etc.
105     otherwise -> do
106       duplicateSignal id
107   setSignalInfo id' (info { sigUse = f use})
108   return id'
109
110 -- | Creates a new internal signal with the same type as the given signal
111 copySignal :: SignalId -> FlattenState SignalId
112 copySignal id = do
113   -- Find the type of the original signal
114   info <- getSignalInfo id
115   let ty = sigTy info
116   -- Generate a new signal (which is SigInternal for now, that will be
117   -- sorted out later on).
118   genSignalId SigInternal ty
119
120 -- | Duplicate the given signal, assigning its value to the new signal.
121 --   Returns the new signal id.
122 duplicateSignal :: SignalId -> FlattenState SignalId
123 duplicateSignal id = do
124   -- Create a new signal
125   id' <- copySignal id
126   -- Assign the old signal to the new signal
127   addDef $ UncondDef (Left id) id'
128   -- Replace the signal with the new signal
129   return id'
130         
131 flattenExpr ::
132   BindMap
133   -> CoreExpr
134   -> FlattenState ([SignalMap], SignalMap)
135
136 flattenExpr binds lam@(Lam b expr) = do
137   -- Find the type of the binder
138   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
139   -- Create signal names for the binder
140   defs <- genSignals arg_ty
141   -- Add name hints to the generated signals
142   let binder_name = Name.getOccString b
143   Traversable.mapM (addNameHint binder_name) defs
144   let binds' = (b, Left defs):binds
145   (args, res) <- flattenExpr binds' expr
146   return (defs : args, res)
147
148 flattenExpr binds var@(Var id) =
149   case Var.globalIdVarDetails id of
150     IdInfo.NotGlobalId ->
151       let 
152         bind = Maybe.fromMaybe
153           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
154           (lookup id binds) 
155       in
156         case bind of
157           Left sig_use -> return ([], sig_use)
158           Right _ -> error "Higher order functions not supported."
159     IdInfo.DataConWorkId datacon -> do
160       if DataCon.isTupleCon datacon && (null $ DataCon.dataConAllTyVars datacon)
161         then do
162           -- Empty tuple construction
163           return ([], Tuple [])
164         else do
165           lit <- dataConToLiteral datacon
166           let ty = CoreUtils.exprType var
167           sig_id <- genSignalId SigInternal ty
168           -- Add a name hint to the signal
169           addNameHint (Name.getOccString id) sig_id
170           addDef (UncondDef (Right $ Literal lit) sig_id)
171           return ([], Single sig_id)
172     otherwise ->
173       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
174
175 flattenExpr binds app@(App _ _) = do
176   -- Is this a data constructor application?
177   case CoreUtils.exprIsConApp_maybe app of
178     -- Is this a tuple construction?
179     Just (dc, args) -> if DataCon.isTupleCon dc 
180       then
181         flattenBuildTupleExpr binds (dataConAppArgs dc args)
182       else
183         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
184     otherwise ->
185       -- Normal function application
186       let ((Var f), args) = collectArgs app in
187       let fname = Name.getOccString f in
188       if fname == "fst" || fname == "snd" then do
189         (args', Tuple [a, b]) <- flattenExpr binds (last args)
190         return (args', if fname == "fst" then a else b)
191       else if fname == "patError" then do
192         -- This is essentially don't care, since the program will error out
193         -- here. We'll just define undriven signals here.
194         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
195         args <- mapM genSignals argtys
196         res <- genSignals resty
197         mapM (Traversable.mapM (addNameHint "NC")) args
198         Traversable.mapM (addNameHint "NC") res
199         return (args, res)
200       else if fname == "==" then do
201         -- Flatten the last two arguments (this skips the type arguments)
202         ([], a) <- flattenExpr binds (last $ init args)
203         ([], b) <- flattenExpr binds (last args)
204         res <- mkEqComparisons a b
205         return ([], res)
206       else
207         flattenApplicationExpr binds (CoreUtils.exprType app) f args
208   where
209     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
210     mkEqComparisons a b = do
211       let zipped = zipValueMaps a b
212       Traversable.mapM mkEqComparison zipped
213
214     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
215     mkEqComparison (a, b) = do
216       -- Generate a signal to hold our result
217       res <- genSignalId SigInternal TysWiredIn.boolTy
218       -- Add a name hint to the signal
219       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
220       addDef (UncondDef (Right $ Eq a b) res)
221       return res
222
223     flattenBuildTupleExpr binds args = do
224       -- Flatten each of our args
225       flat_args <- (State.mapM (flattenExpr binds) args)
226       -- Check and split each of the arguments
227       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
228       let res = Tuple arg_ress
229       return ([], res)
230
231     -- | Flatten a normal application expression
232     flattenApplicationExpr binds ty f args = do
233       -- Find the function to call
234       let func = appToHsFunction ty f args
235       -- Flatten each of our args
236       flat_args <- (State.mapM (flattenExpr binds) args)
237       -- Check and split each of the arguments
238       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
239       -- Generate signals for our result
240       res <- genSignals ty
241       -- Add name hints to the generated signals
242       let resname = Name.getOccString f ++ "_res"
243       Traversable.mapM (addNameHint resname) res
244       -- Create the function application
245       let app = FApp {
246         appFunc = func,
247         appArgs = arg_ress,
248         appRes  = res
249       }
250       addDef app
251       return ([], res)
252     -- | Check a flattened expression to see if it is valid to use as a
253     --   function argument. The first argument is the original expression for
254     --   use in the error message.
255     checkArg arg flat =
256       let (args, res) = flat in
257       if not (null args)
258         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
259         else flat 
260
261 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
262   (b_args, b_res) <- flattenExpr binds bexpr
263   if not (null b_args)
264     then
265       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
266     else do
267       let binds' = (b, Left b_res) : binds
268       -- Add name hints to the generated signals
269       let binder_name = Name.getOccString b
270       Traversable.mapM (addNameHint binder_name) b_res
271       flattenExpr binds' expr
272
273 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
274
275 flattenExpr binds expr@(Case scrut b _ alts) = do
276   -- TODO: Special casing for higher order functions
277   -- Flatten the scrutinee
278   (_, res) <- flattenExpr binds scrut
279   case alts of
280     -- TODO include b in the binds list
281     [alt] -> flattenSingleAltCaseExpr binds res b alt
282     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
283     otherwise -> flattenMultipleAltCaseExpr binds res b (reverse alts)
284   where
285     flattenSingleAltCaseExpr ::
286       BindMap
287                                 -- A list of bindings in effect
288       -> SignalMap              -- The scrutinee
289       -> CoreBndr               -- The binder to bind the scrutinee to
290       -> CoreAlt                -- The single alternative
291       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
292
293     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
294       if DataCon.isTupleCon datacon
295         then do
296           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
297           -- the existing bindings list and get the portname map for each of
298           -- it's elements.
299           let Tuple tuple_sigs = scrut
300           -- Add name hints to the returned signals
301           let binder_name = Name.getOccString b
302           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
303           -- Merge our existing binds with the new binds.
304           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
305           -- Expand the expression with the new binds list
306           flattenExpr binds' expr
307         else
308           if null bind_vars
309             then
310               -- DataAlts without arguments don't need processing
311               -- (flattenMultipleAltCaseExpr will have done this already).
312               flattenExpr binds expr
313             else
314               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
315
316     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
317       flattenExpr binds expr
318       
319     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
320
321     flattenMultipleAltCaseExpr ::
322       BindMap
323                                 -- A list of bindings in effect
324       -> SignalMap              -- The scrutinee
325       -> CoreBndr               -- The binder to bind the scrutinee to
326       -> [CoreAlt]              -- The alternatives
327       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
328
329     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
330       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
331       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
332       case a of
333         (DataAlt datacon, bind_vars, expr) -> do
334           if isDontCare datacon 
335             then do
336               -- Completely skip the dontcare cases
337               return (args', res')
338             else do
339               lit <- dataConToLiteral datacon
340               -- The scrutinee must be a single signal
341               let Single sig = scrut
342               -- Create a signal that contains a boolean
343               boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
344               addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
345               let expr = EqLit sig lit
346               addDef (UncondDef (Right expr) boolsigid)
347               -- Create conditional assignments of either args/res or
348               -- args'/res based on boolsigid, and return the result.
349               -- TODO: It seems this adds the name hint twice?
350               our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
351               our_res  <- mkConditionals boolsigid res res'
352               return (our_args, our_res)
353         otherwise ->
354           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
355       where
356         -- Select either the first or second signal map depending on the value
357         -- of the first argument (True == first map, False == second map)
358         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
359         mkConditionals boolsigid true false = do
360           let zipped = zipValueMaps true false
361           Traversable.mapM (mkConditional boolsigid) zipped
362
363         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
364         mkConditional boolsigid (true, false) = do
365           -- Create a new signal (true and false should be identically typed,
366           -- so it doesn't matter which one we copy).
367           res <- copySignal true
368           addDef (CondDef boolsigid true false res)
369           return res
370
371     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
372       flattenSingleAltCaseExpr binds scrut b a
373
374 flattenExpr _ expr = do
375   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
376
377 -- | Is the given data constructor a dontcare?
378 isDontCare :: DataCon.DataCon -> Bool
379 isDontCare datacon =
380   case Name.getOccString tyname of
381     -- TODO: Do something more robust than string matching
382     "Bit" ->
383       Name.getOccString dcname  == "DontCare"
384     otherwise ->
385       False
386   where
387     tycon = DataCon.dataConTyCon datacon
388     tyname = TyCon.tyConName tycon
389     dcname = DataCon.dataConName datacon
390
391 -- | Translates a dataconstructor without arguments to the corresponding
392 --   literal.
393 dataConToLiteral :: DataCon.DataCon -> FlattenState String
394 dataConToLiteral datacon = do
395   let tycon = DataCon.dataConTyCon datacon
396   let tyname = TyCon.tyConName tycon
397   case Name.getOccString tyname of
398     -- TODO: Do something more robust than string matching
399     "Bit"      -> do
400       let dcname = DataCon.dataConName datacon
401       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"; "DontCare" -> "'-'"
402       return lit
403     "Bool" -> do
404       let dcname = DataCon.dataConName datacon
405       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
406       return lit
407     otherwise ->
408       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
409
410 appToHsFunction ::
411   Type.Type       -- ^ The return type
412   -> Var.Var      -- ^ The function to call
413   -> [CoreExpr]   -- ^ The function arguments
414   -> HsFunction   -- ^ The needed HsFunction
415
416 appToHsFunction ty f args =
417   HsFunction hsname hsargs hsres
418   where
419     hsname = Name.getOccString f
420     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
421     hsres  = useAsPort (mkHsValueMap ty)
422
423 -- | Filters non-state signals and returns the state number and signal id for
424 --   state values.
425 filterState ::
426   SignalId                       -- | The signal id to look at
427   -> HsValueUse                  -- | How is this signal used?
428   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
429                                  --   signal was used as state
430
431 filterState id (State num) = 
432   Just (num, id)
433 filterState _ _ = Nothing
434
435 -- | Returns a list of the state number and signal id of all used-as-state
436 --   signals in the given maps.
437 stateList ::
438   HsUseMap
439   -> (SignalMap)
440   -> [(StateId, SignalId)]
441
442 stateList uses signals =
443     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
444   
445 -- vim: set ts=8 sw=2 sts=2 expandtab: