Allow references to global values without arguments.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified Literal
12 import qualified CoreUtils
13 import qualified TysWiredIn
14 import qualified IdInfo
15 import qualified Data.Traversable as Traversable
16 import qualified Data.Foldable as Foldable
17 import Control.Applicative
18 import Outputable ( showSDoc, ppr )
19 import qualified Control.Monad.Trans.State as State
20
21 import HsValueMap
22 import TranslatorTypes
23 import FlattenTypes
24
25 -- Extract the arguments from a data constructor application (that is, the
26 -- normal args, leaving out the type args).
27 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
28 dataConAppArgs dc args =
29     drop tycount args
30   where
31     tycount = length $ DataCon.dataConAllTyVars dc
32
33 genSignals ::
34   Type.Type
35   -> FlattenState SignalMap
36
37 genSignals ty =
38   -- First generate a map with the right structure containing the types, and
39   -- generate signals for each of them.
40   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
41
42 -- | Marks a signal as the given SigUse, if its id is in the list of id's
43 --   given.
44 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
45 markSignals use ids (id, info) =
46   (id, info')
47   where
48     info' = if id `elem` ids then info { sigUse = use} else info
49
50 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
51 markSignal use id = markSignals use [id]
52
53 -- | Flatten a haskell function
54 flattenFunction ::
55   HsFunction                      -- ^ The function to flatten
56   -> CoreBind                     -- ^ The function value
57   -> FlatFunction                 -- ^ The resulting flat function
58
59 flattenFunction _ (Rec _) = error "Recursive binders not supported"
60 flattenFunction hsfunc bind@(NonRec var expr) =
61   FlatFunction args res defs sigs
62   where
63     init_state        = ([], [], 0)
64     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
65     (defs, sigs, _)   = end_state
66     (args, res)       = fres
67
68 flattenTopExpr ::
69   HsFunction
70   -> CoreExpr
71   -> FlattenState ([SignalMap], SignalMap)
72
73 flattenTopExpr hsfunc expr = do
74   -- Flatten the expression
75   (args, res) <- flattenExpr [] expr
76   
77   -- Join the signal ids and uses together
78   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
79   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
80   -- Set the signal uses for each argument / result, possibly updating
81   -- argument or result signals.
82   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
83   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
84   return (args', res')
85   where
86     args_use Port = SigPortIn
87     args_use (State n) = SigStateOld n
88     res_use Port = SigPortOut
89     res_use (State n) = SigStateNew n
90
91
92 hsUseToSigUse :: 
93   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
94   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
95   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
96                               --   same as the input, but it could be different.
97 hsUseToSigUse f (id, use) = do
98   info <- getSignalInfo id
99   id' <- case sigUse info of 
100     -- Internal signals can be marked as different uses freely.
101     SigInternal -> do
102       return id
103     -- Signals that already have another use, must be duplicated before
104     -- marking. This prevents signals mapping to the same input or output
105     -- port or state variables and ports overlapping, etc.
106     otherwise -> do
107       duplicateSignal id
108   setSignalInfo id' (info { sigUse = f use})
109   return id'
110
111 -- | Creates a new internal signal with the same type as the given signal
112 copySignal :: SignalId -> FlattenState SignalId
113 copySignal id = do
114   -- Find the type of the original signal
115   info <- getSignalInfo id
116   let ty = sigTy info
117   -- Generate a new signal (which is SigInternal for now, that will be
118   -- sorted out later on).
119   genSignalId SigInternal ty
120
121 -- | Duplicate the given signal, assigning its value to the new signal.
122 --   Returns the new signal id.
123 duplicateSignal :: SignalId -> FlattenState SignalId
124 duplicateSignal id = do
125   -- Create a new signal
126   id' <- copySignal id
127   -- Assign the old signal to the new signal
128   addDef $ UncondDef (Left id) id'
129   -- Replace the signal with the new signal
130   return id'
131         
132 flattenExpr ::
133   BindMap
134   -> CoreExpr
135   -> FlattenState ([SignalMap], SignalMap)
136
137 flattenExpr binds lam@(Lam b expr) = do
138   -- Find the type of the binder
139   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
140   -- Create signal names for the binder
141   defs <- genSignals arg_ty
142   -- Add name hints to the generated signals
143   let binder_name = Name.getOccString b
144   Traversable.mapM (addNameHint binder_name) defs
145   let binds' = (b, Left defs):binds
146   (args, res) <- flattenExpr binds' expr
147   return (defs : args, res)
148
149 flattenExpr binds var@(Var id) =
150   case Var.globalIdVarDetails id of
151     IdInfo.NotGlobalId ->
152       let 
153         bind = Maybe.fromMaybe
154           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
155           (lookup id binds) 
156       in
157         case bind of
158           Left sig_use -> return ([], sig_use)
159           Right _ -> error "Higher order functions not supported."
160     IdInfo.DataConWorkId datacon -> do
161       if DataCon.isTupleCon datacon && (null $ DataCon.dataConAllTyVars datacon)
162         then do
163           -- Empty tuple construction
164           return ([], Tuple [])
165         else do
166           lit <- dataConToLiteral datacon
167           let ty = CoreUtils.exprType var
168           sig_id <- genSignalId SigInternal ty
169           -- Add a name hint to the signal
170           addNameHint (Name.getOccString id) sig_id
171           addDef (UncondDef (Right $ Literal lit) sig_id)
172           return ([], Single sig_id)
173     IdInfo.VanillaGlobal ->
174       -- Treat references to globals as an application with zero elements
175       flattenApplicationExpr binds (CoreUtils.exprType var) id []
176     otherwise ->
177       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
178
179 flattenExpr binds app@(App _ _) = do
180   -- Is this a data constructor application?
181   case CoreUtils.exprIsConApp_maybe app of
182     -- Is this a tuple construction?
183     Just (dc, args) -> if DataCon.isTupleCon dc 
184       then
185         flattenBuildTupleExpr binds (dataConAppArgs dc args)
186       else
187         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
188     otherwise ->
189       -- Normal function application
190       let ((Var f), args) = collectArgs app in
191       let fname = Name.getOccString f in
192       if fname == "fst" || fname == "snd" then do
193         (args', Tuple [a, b]) <- flattenExpr binds (last args)
194         return (args', if fname == "fst" then a else b)
195       else if fname == "patError" then do
196         -- This is essentially don't care, since the program will error out
197         -- here. We'll just define undriven signals here.
198         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
199         args <- mapM genSignals argtys
200         res <- genSignals resty
201         mapM (Traversable.mapM (addNameHint "NC")) args
202         Traversable.mapM (addNameHint "NC") res
203         return (args, res)
204       else if fname == "==" then do
205         -- Flatten the last two arguments (this skips the type arguments)
206         ([], a) <- flattenExpr binds (last $ init args)
207         ([], b) <- flattenExpr binds (last args)
208         res <- mkEqComparisons a b
209         return ([], res)
210       else
211         flattenApplicationExpr binds (CoreUtils.exprType app) f args
212   where
213     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
214     mkEqComparisons a b = do
215       let zipped = zipValueMaps a b
216       Traversable.mapM mkEqComparison zipped
217
218     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
219     mkEqComparison (a, b) = do
220       -- Generate a signal to hold our result
221       res <- genSignalId SigInternal TysWiredIn.boolTy
222       -- Add a name hint to the signal
223       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
224       addDef (UncondDef (Right $ Eq a b) res)
225       return res
226
227     flattenBuildTupleExpr binds args = do
228       -- Flatten each of our args
229       flat_args <- (mapM (flattenExpr binds) args)
230       -- Check and split each of the arguments
231       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
232       let res = Tuple arg_ress
233       return ([], res)
234
235 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
236   (b_args, b_res) <- flattenExpr binds bexpr
237   if not (null b_args)
238     then
239       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
240     else do
241       let binds' = (b, Left b_res) : binds
242       -- Add name hints to the generated signals
243       let binder_name = Name.getOccString b
244       Traversable.mapM (addNameHint binder_name) b_res
245       flattenExpr binds' expr
246
247 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
248
249 flattenExpr binds expr@(Case scrut b _ alts) = do
250   -- TODO: Special casing for higher order functions
251   -- Flatten the scrutinee
252   (_, res) <- flattenExpr binds scrut
253   case alts of
254     -- TODO include b in the binds list
255     [alt] -> flattenSingleAltCaseExpr binds res b alt
256     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
257     otherwise -> flattenMultipleAltCaseExpr binds res b (reverse alts)
258   where
259     flattenSingleAltCaseExpr ::
260       BindMap
261                                 -- A list of bindings in effect
262       -> SignalMap              -- The scrutinee
263       -> CoreBndr               -- The binder to bind the scrutinee to
264       -> CoreAlt                -- The single alternative
265       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
266
267     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
268       if DataCon.isTupleCon datacon
269         then do
270           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
271           -- the existing bindings list and get the portname map for each of
272           -- it's elements.
273           let Tuple tuple_sigs = scrut
274           -- Add name hints to the returned signals
275           let binder_name = Name.getOccString b
276           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
277           -- Merge our existing binds with the new binds.
278           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
279           -- Expand the expression with the new binds list
280           flattenExpr binds' expr
281         else
282           if null bind_vars
283             then
284               -- DataAlts without arguments don't need processing
285               -- (flattenMultipleAltCaseExpr will have done this already).
286               flattenExpr binds expr
287             else
288               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
289
290     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
291       flattenExpr binds expr
292       
293     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
294
295     flattenMultipleAltCaseExpr ::
296       BindMap
297                                 -- A list of bindings in effect
298       -> SignalMap              -- The scrutinee
299       -> CoreBndr               -- The binder to bind the scrutinee to
300       -> [CoreAlt]              -- The alternatives
301       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
302
303     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
304       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
305       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
306       case a of
307         (DataAlt datacon, bind_vars, expr) -> do
308           lit <- dataConToLiteral datacon
309           -- The scrutinee must be a single signal
310           let Single sig = scrut
311           -- Create a signal that contains a boolean
312           boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
313           addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
314           let expr = EqLit sig lit
315           addDef (UncondDef (Right expr) boolsigid)
316           -- Create conditional assignments of either args/res or
317           -- args'/res based on boolsigid, and return the result.
318           -- TODO: It seems this adds the name hint twice?
319           our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
320           our_res  <- mkConditionals boolsigid res res'
321           return (our_args, our_res)
322         otherwise ->
323           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
324       where
325         -- Select either the first or second signal map depending on the value
326         -- of the first argument (True == first map, False == second map)
327         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
328         mkConditionals boolsigid true false = do
329           let zipped = zipValueMaps true false
330           Traversable.mapM (mkConditional boolsigid) zipped
331
332         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
333         mkConditional boolsigid (true, false) = do
334           -- Create a new signal (true and false should be identically typed,
335           -- so it doesn't matter which one we copy).
336           res <- copySignal true
337           addDef (CondDef boolsigid true false res)
338           return res
339
340     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
341       flattenSingleAltCaseExpr binds scrut b a
342
343 flattenExpr _ expr = do
344   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
345
346 -- | Flatten a normal application expression
347 flattenApplicationExpr binds ty f args = do
348   -- Find the function to call
349   let func = appToHsFunction ty f args
350   -- Flatten each of our args
351   flat_args <- (mapM (flattenExpr binds) args)
352   -- Check and split each of the arguments
353   let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
354   -- Generate signals for our result
355   res <- genSignals ty
356   -- Add name hints to the generated signals
357   let resname = Name.getOccString f ++ "_res"
358   Traversable.mapM (addNameHint resname) res
359   -- Create the function application
360   let app = FApp {
361     appFunc = func,
362     appArgs = arg_ress,
363     appRes  = res
364   }
365   addDef app
366   return ([], res)
367 -- | Check a flattened expression to see if it is valid to use as a
368 --   function argument. The first argument is the original expression for
369 --   use in the error message.
370 checkArg arg flat =
371   let (args, res) = flat in
372   if not (null args)
373     then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
374     else flat 
375
376 -- | Translates a dataconstructor without arguments to the corresponding
377 --   literal.
378 dataConToLiteral :: DataCon.DataCon -> FlattenState String
379 dataConToLiteral datacon = do
380   let tycon = DataCon.dataConTyCon datacon
381   let tyname = TyCon.tyConName tycon
382   case Name.getOccString tyname of
383     -- TODO: Do something more robust than string matching
384     "Bit"      -> do
385       let dcname = DataCon.dataConName datacon
386       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"
387       return lit
388     "Bool" -> do
389       let dcname = DataCon.dataConName datacon
390       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
391       return lit
392     otherwise ->
393       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
394
395 appToHsFunction ::
396   Type.Type       -- ^ The return type
397   -> Var.Var      -- ^ The function to call
398   -> [CoreExpr]   -- ^ The function arguments
399   -> HsFunction   -- ^ The needed HsFunction
400
401 appToHsFunction ty f args =
402   HsFunction hsname hsargs hsres
403   where
404     hsname = Name.getOccString f
405     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
406     hsres  = useAsPort (mkHsValueMap ty)
407
408 -- | Filters non-state signals and returns the state number and signal id for
409 --   state values.
410 filterState ::
411   SignalId                       -- | The signal id to look at
412   -> HsValueUse                  -- | How is this signal used?
413   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
414                                  --   signal was used as state
415
416 filterState id (State num) = 
417   Just (num, id)
418 filterState _ _ = Nothing
419
420 -- | Returns a list of the state number and signal id of all used-as-state
421 --   signals in the given maps.
422 stateList ::
423   HsUseMap
424   -> (SignalMap)
425   -> [(StateId, SignalId)]
426
427 stateList uses signals =
428     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
429   
430 -- vim: set ts=8 sw=2 sts=2 expandtab: