Enable the DontCare value for Bit again.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import Control.Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified CoreUtils
12 import qualified TysWiredIn
13 import qualified IdInfo
14 import qualified Data.Traversable as Traversable
15 import qualified Data.Foldable as Foldable
16 import Control.Applicative
17 import Outputable ( showSDoc, ppr )
18 import qualified Control.Monad.State as State
19
20 import HsValueMap
21 import TranslatorTypes
22 import FlattenTypes
23
24 -- Extract the arguments from a data constructor application (that is, the
25 -- normal args, leaving out the type args).
26 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
27 dataConAppArgs dc args =
28     drop tycount args
29   where
30     tycount = length $ DataCon.dataConAllTyVars dc
31
32 genSignals ::
33   Type.Type
34   -> FlattenState SignalMap
35
36 genSignals ty =
37   -- First generate a map with the right structure containing the types, and
38   -- generate signals for each of them.
39   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
40
41 -- | Marks a signal as the given SigUse, if its id is in the list of id's
42 --   given.
43 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
44 markSignals use ids (id, info) =
45   (id, info')
46   where
47     info' = if id `elem` ids then info { sigUse = use} else info
48
49 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
50 markSignal use id = markSignals use [id]
51
52 -- | Flatten a haskell function
53 flattenFunction ::
54   HsFunction                      -- ^ The function to flatten
55   -> CoreBind                     -- ^ The function value
56   -> FlatFunction                 -- ^ The resulting flat function
57
58 flattenFunction _ (Rec _) = error "Recursive binders not supported"
59 flattenFunction hsfunc bind@(NonRec var expr) =
60   FlatFunction args res defs sigs
61   where
62     init_state        = ([], [], 0)
63     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
64     (defs, sigs, _)   = end_state
65     (args, res)       = fres
66
67 flattenTopExpr ::
68   HsFunction
69   -> CoreExpr
70   -> FlattenState ([SignalMap], SignalMap)
71
72 flattenTopExpr hsfunc expr = do
73   -- Flatten the expression
74   (args, res) <- flattenExpr [] expr
75   
76   -- Join the signal ids and uses together
77   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
78   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
79   -- Set the signal uses for each argument / result, possibly updating
80   -- argument or result signals.
81   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
82   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
83   return (args', res')
84   where
85     args_use Port = SigPortIn
86     args_use (State n) = SigStateOld n
87     res_use Port = SigPortOut
88     res_use (State n) = SigStateNew n
89
90
91 hsUseToSigUse :: 
92   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
93   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
94   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
95                               --   same as the input, but it could be different.
96 hsUseToSigUse f (id, use) = do
97   info <- getSignalInfo id
98   id' <- case sigUse info of 
99     -- Internal signals can be marked as different uses freely.
100     SigInternal -> do
101       return id
102     -- Signals that already have another use, must be duplicated before
103     -- marking. This prevents signals mapping to the same input or output
104     -- port or state variables and ports overlapping, etc.
105     otherwise -> do
106       duplicateSignal id
107   setSignalInfo id' (info { sigUse = f use})
108   return id'
109
110 -- | Creates a new internal signal with the same type as the given signal
111 copySignal :: SignalId -> FlattenState SignalId
112 copySignal id = do
113   -- Find the type of the original signal
114   info <- getSignalInfo id
115   let ty = sigTy info
116   -- Generate a new signal (which is SigInternal for now, that will be
117   -- sorted out later on).
118   genSignalId SigInternal ty
119
120 -- | Duplicate the given signal, assigning its value to the new signal.
121 --   Returns the new signal id.
122 duplicateSignal :: SignalId -> FlattenState SignalId
123 duplicateSignal id = do
124   -- Create a new signal
125   id' <- copySignal id
126   -- Assign the old signal to the new signal
127   addDef $ UncondDef (Left id) id'
128   -- Replace the signal with the new signal
129   return id'
130         
131 flattenExpr ::
132   BindMap
133   -> CoreExpr
134   -> FlattenState ([SignalMap], SignalMap)
135
136 flattenExpr binds lam@(Lam b expr) = do
137   -- Find the type of the binder
138   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
139   -- Create signal names for the binder
140   defs <- genSignals arg_ty
141   let binds' = (b, Left defs):binds
142   (args, res) <- flattenExpr binds' expr
143   return (defs : args, res)
144
145 flattenExpr binds var@(Var id) =
146   case Var.globalIdVarDetails id of
147     IdInfo.NotGlobalId ->
148       let 
149         bind = Maybe.fromMaybe
150           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
151           (lookup id binds) 
152       in
153         case bind of
154           Left sig_use -> return ([], sig_use)
155           Right _ -> error "Higher order functions not supported."
156     IdInfo.DataConWorkId datacon -> do
157       lit <- dataConToLiteral datacon
158       let ty = CoreUtils.exprType var
159       id <- genSignalId SigInternal ty
160       addDef (UncondDef (Right $ Literal lit) id)
161       return ([], Single id)
162     otherwise ->
163       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
164
165 flattenExpr binds app@(App _ _) = do
166   -- Is this a data constructor application?
167   case CoreUtils.exprIsConApp_maybe app of
168     -- Is this a tuple construction?
169     Just (dc, args) -> if DataCon.isTupleCon dc 
170       then
171         flattenBuildTupleExpr binds (dataConAppArgs dc args)
172       else
173         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
174     otherwise ->
175       -- Normal function application
176       let ((Var f), args) = collectArgs app in
177       let fname = Name.getOccString f in
178       if fname == "fst" || fname == "snd" then do
179         (args', Tuple [a, b]) <- flattenExpr binds (last args)
180         return (args', if fname == "fst" then a else b)
181       else if fname == "patError" then do
182         -- This is essentially don't care, since the program will error out
183         -- here. We'll just define undriven signals here.
184         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
185         args <- mapM genSignals argtys
186         res <- genSignals resty
187         return (args, res)
188       else if fname == "==" then do
189         -- Flatten the last two arguments (this skips the type arguments)
190         ([], a) <- flattenExpr binds (last $ init args)
191         ([], b) <- flattenExpr binds (last args)
192         res <- mkEqComparisons a b
193         return ([], res)
194       else
195         flattenApplicationExpr binds (CoreUtils.exprType app) f args
196   where
197     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
198     mkEqComparisons a b = do
199       let zipped = zipValueMaps a b
200       Traversable.mapM mkEqComparison zipped
201
202     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
203     mkEqComparison (a, b) = do
204       -- Generate a signal to hold our result
205       res <- genSignalId SigInternal TysWiredIn.boolTy
206       addDef (UncondDef (Right $ Eq a b) res)
207       return res
208
209     flattenBuildTupleExpr binds args = do
210       -- Flatten each of our args
211       flat_args <- (State.mapM (flattenExpr binds) args)
212       -- Check and split each of the arguments
213       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
214       let res = Tuple arg_ress
215       return ([], res)
216
217     -- | Flatten a normal application expression
218     flattenApplicationExpr binds ty f args = do
219       -- Find the function to call
220       let func = appToHsFunction ty f args
221       -- Flatten each of our args
222       flat_args <- (State.mapM (flattenExpr binds) args)
223       -- Check and split each of the arguments
224       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
225       -- Generate signals for our result
226       res <- genSignals ty
227       -- Create the function application
228       let app = FApp {
229         appFunc = func,
230         appArgs = arg_ress,
231         appRes  = res
232       }
233       addDef app
234       return ([], res)
235     -- | Check a flattened expression to see if it is valid to use as a
236     --   function argument. The first argument is the original expression for
237     --   use in the error message.
238     checkArg arg flat =
239       let (args, res) = flat in
240       if not (null args)
241         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
242         else flat 
243
244 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
245   (b_args, b_res) <- flattenExpr binds bexpr
246   if not (null b_args)
247     then
248       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
249     else
250       let binds' = (b, Left b_res) : binds in
251       flattenExpr binds' expr
252
253 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
254
255 flattenExpr binds expr@(Case scrut b _ alts) = do
256   -- TODO: Special casing for higher order functions
257   -- Flatten the scrutinee
258   (_, res) <- flattenExpr binds scrut
259   case alts of
260     -- TODO include b in the binds list
261     [alt] -> flattenSingleAltCaseExpr binds res b alt
262     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
263     otherwise -> flattenMultipleAltCaseExpr binds res b (reverse alts)
264   where
265     flattenSingleAltCaseExpr ::
266       BindMap
267                                 -- A list of bindings in effect
268       -> SignalMap              -- The scrutinee
269       -> CoreBndr               -- The binder to bind the scrutinee to
270       -> CoreAlt                -- The single alternative
271       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
272
273     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
274       if DataCon.isTupleCon datacon
275         then
276           let
277             -- Unpack the scrutinee (which must be a variable bound to a tuple) in
278             -- the existing bindings list and get the portname map for each of
279             -- it's elements.
280             Tuple tuple_sigs = scrut
281             -- Merge our existing binds with the new binds.
282             binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
283           in
284             -- Expand the expression with the new binds list
285             flattenExpr binds' expr
286         else
287           if null bind_vars
288             then
289               -- DataAlts without arguments don't need processing
290               -- (flattenMultipleAltCaseExpr will have done this already).
291               flattenExpr binds expr
292             else
293               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
294
295     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
296       flattenExpr binds expr
297       
298     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
299
300     flattenMultipleAltCaseExpr ::
301       BindMap
302                                 -- A list of bindings in effect
303       -> SignalMap              -- The scrutinee
304       -> CoreBndr               -- The binder to bind the scrutinee to
305       -> [CoreAlt]              -- The alternatives
306       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
307
308     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
309       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
310       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
311       case a of
312         (DataAlt datacon, bind_vars, expr) -> do
313           if isDontCare datacon 
314             then do
315               -- Completely skip the dontcare cases
316               return (args', res')
317             else do
318               lit <- dataConToLiteral datacon
319               -- The scrutinee must be a single signal
320               let Single sig = scrut
321               -- Create a signal that contains a boolean
322               boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
323               let expr = EqLit sig lit
324               addDef (UncondDef (Right expr) boolsigid)
325               -- Create conditional assignments of either args/res or
326               -- args'/res based on boolsigid, and return the result.
327               our_args <- zipWithM (mkConditionals boolsigid) args args'
328               our_res  <- mkConditionals boolsigid res res'
329               return (our_args, our_res)
330         otherwise ->
331           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
332       where
333         -- Select either the first or second signal map depending on the value
334         -- of the first argument (True == first map, False == second map)
335         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
336         mkConditionals boolsigid true false = do
337           let zipped = zipValueMaps true false
338           Traversable.mapM (mkConditional boolsigid) zipped
339
340         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
341         mkConditional boolsigid (true, false) = do
342           -- Create a new signal (true and false should be identically typed,
343           -- so it doesn't matter which one we copy).
344           res <- copySignal true
345           addDef (CondDef boolsigid true false res)
346           return res
347
348     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
349       flattenSingleAltCaseExpr binds scrut b a
350
351 flattenExpr _ expr = do
352   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
353
354 -- | Is the given data constructor a dontcare?
355 isDontCare :: DataCon.DataCon -> Bool
356 isDontCare datacon =
357   case Name.getOccString tyname of
358     -- TODO: Do something more robust than string matching
359     "Bit" ->
360       Name.getOccString dcname  == "DontCare"
361     otherwise ->
362       False
363   where
364     tycon = DataCon.dataConTyCon datacon
365     tyname = TyCon.tyConName tycon
366     dcname = DataCon.dataConName datacon
367
368 -- | Translates a dataconstructor without arguments to the corresponding
369 --   literal.
370 dataConToLiteral :: DataCon.DataCon -> FlattenState String
371 dataConToLiteral datacon = do
372   let tycon = DataCon.dataConTyCon datacon
373   let tyname = TyCon.tyConName tycon
374   case Name.getOccString tyname of
375     -- TODO: Do something more robust than string matching
376     "Bit"      -> do
377       let dcname = DataCon.dataConName datacon
378       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"; "DontCare" -> "'-'"
379       return lit
380     "Bool" -> do
381       let dcname = DataCon.dataConName datacon
382       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
383       return lit
384     otherwise ->
385       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
386
387 appToHsFunction ::
388   Type.Type       -- ^ The return type
389   -> Var.Var      -- ^ The function to call
390   -> [CoreExpr]   -- ^ The function arguments
391   -> HsFunction   -- ^ The needed HsFunction
392
393 appToHsFunction ty f args =
394   HsFunction hsname hsargs hsres
395   where
396     hsname = Name.getOccString f
397     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
398     hsres  = useAsPort (mkHsValueMap ty)
399
400 -- | Filters non-state signals and returns the state number and signal id for
401 --   state values.
402 filterState ::
403   SignalId                       -- | The signal id to look at
404   -> HsValueUse                  -- | How is this signal used?
405   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
406                                  --   signal was used as state
407
408 filterState id (State num) = 
409   Just (num, id)
410 filterState _ _ = Nothing
411
412 -- | Returns a list of the state number and signal id of all used-as-state
413 --   signals in the given maps.
414 stateList ::
415   HsUseMap
416   -> (SignalMap)
417   -> [(StateId, SignalId)]
418
419 stateList uses signals =
420     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
421   
422 -- | Returns pairs of signals that should be mapped to state in this function.
423 getOwnStates ::
424   HsFunction                      -- | The function to look at
425   -> FlatFunction                 -- | The function to look at
426   -> [(StateId, SignalInfo, SignalInfo)]   
427         -- | The state signals. The first is the state number, the second the
428         --   signal to assign the current state to, the last is the signal
429         --   that holds the new state.
430
431 getOwnStates hsfunc flatfunc =
432   [(old_num, old_info, new_info) 
433     | (old_num, old_info) <- args_states
434     , (new_num, new_info) <- res_states
435     , old_num == new_num]
436   where
437     sigs = flat_sigs flatfunc
438     -- Translate args and res to lists of (statenum, sigid)
439     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
440     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
441     -- Replace the second tuple element with the corresponding SignalInfo
442     args_states = map (Arrow.second $ signalInfo sigs) args
443     res_states = map (Arrow.second $ signalInfo sigs) res
444
445     
446 -- vim: set ts=8 sw=2 sts=2 expandtab: