Move some code out of the flattenExpr to global scope.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import qualified Control.Monad as Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified Literal
12 import qualified CoreUtils
13 import qualified TysWiredIn
14 import qualified IdInfo
15 import qualified Data.Traversable as Traversable
16 import qualified Data.Foldable as Foldable
17 import Control.Applicative
18 import Outputable ( showSDoc, ppr )
19 import qualified Control.Monad.Trans.State as State
20
21 import HsValueMap
22 import TranslatorTypes
23 import FlattenTypes
24
25 -- Extract the arguments from a data constructor application (that is, the
26 -- normal args, leaving out the type args).
27 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
28 dataConAppArgs dc args =
29     drop tycount args
30   where
31     tycount = length $ DataCon.dataConAllTyVars dc
32
33 genSignals ::
34   Type.Type
35   -> FlattenState SignalMap
36
37 genSignals ty =
38   -- First generate a map with the right structure containing the types, and
39   -- generate signals for each of them.
40   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
41
42 -- | Marks a signal as the given SigUse, if its id is in the list of id's
43 --   given.
44 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
45 markSignals use ids (id, info) =
46   (id, info')
47   where
48     info' = if id `elem` ids then info { sigUse = use} else info
49
50 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
51 markSignal use id = markSignals use [id]
52
53 -- | Flatten a haskell function
54 flattenFunction ::
55   HsFunction                      -- ^ The function to flatten
56   -> CoreBind                     -- ^ The function value
57   -> FlatFunction                 -- ^ The resulting flat function
58
59 flattenFunction _ (Rec _) = error "Recursive binders not supported"
60 flattenFunction hsfunc bind@(NonRec var expr) =
61   FlatFunction args res defs sigs
62   where
63     init_state        = ([], [], 0)
64     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
65     (defs, sigs, _)   = end_state
66     (args, res)       = fres
67
68 flattenTopExpr ::
69   HsFunction
70   -> CoreExpr
71   -> FlattenState ([SignalMap], SignalMap)
72
73 flattenTopExpr hsfunc expr = do
74   -- Flatten the expression
75   (args, res) <- flattenExpr [] expr
76   
77   -- Join the signal ids and uses together
78   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
79   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
80   -- Set the signal uses for each argument / result, possibly updating
81   -- argument or result signals.
82   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
83   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
84   return (args', res')
85   where
86     args_use Port = SigPortIn
87     args_use (State n) = SigStateOld n
88     res_use Port = SigPortOut
89     res_use (State n) = SigStateNew n
90
91
92 hsUseToSigUse :: 
93   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
94   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
95   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
96                               --   same as the input, but it could be different.
97 hsUseToSigUse f (id, use) = do
98   info <- getSignalInfo id
99   id' <- case sigUse info of 
100     -- Internal signals can be marked as different uses freely.
101     SigInternal -> do
102       return id
103     -- Signals that already have another use, must be duplicated before
104     -- marking. This prevents signals mapping to the same input or output
105     -- port or state variables and ports overlapping, etc.
106     otherwise -> do
107       duplicateSignal id
108   setSignalInfo id' (info { sigUse = f use})
109   return id'
110
111 -- | Creates a new internal signal with the same type as the given signal
112 copySignal :: SignalId -> FlattenState SignalId
113 copySignal id = do
114   -- Find the type of the original signal
115   info <- getSignalInfo id
116   let ty = sigTy info
117   -- Generate a new signal (which is SigInternal for now, that will be
118   -- sorted out later on).
119   genSignalId SigInternal ty
120
121 -- | Duplicate the given signal, assigning its value to the new signal.
122 --   Returns the new signal id.
123 duplicateSignal :: SignalId -> FlattenState SignalId
124 duplicateSignal id = do
125   -- Create a new signal
126   id' <- copySignal id
127   -- Assign the old signal to the new signal
128   addDef $ UncondDef (Left id) id'
129   -- Replace the signal with the new signal
130   return id'
131         
132 flattenExpr ::
133   BindMap
134   -> CoreExpr
135   -> FlattenState ([SignalMap], SignalMap)
136
137 flattenExpr binds lam@(Lam b expr) = do
138   -- Find the type of the binder
139   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
140   -- Create signal names for the binder
141   defs <- genSignals arg_ty
142   -- Add name hints to the generated signals
143   let binder_name = Name.getOccString b
144   Traversable.mapM (addNameHint binder_name) defs
145   let binds' = (b, Left defs):binds
146   (args, res) <- flattenExpr binds' expr
147   return (defs : args, res)
148
149 flattenExpr binds var@(Var id) =
150   case Var.globalIdVarDetails id of
151     IdInfo.NotGlobalId ->
152       let 
153         bind = Maybe.fromMaybe
154           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
155           (lookup id binds) 
156       in
157         case bind of
158           Left sig_use -> return ([], sig_use)
159           Right _ -> error "Higher order functions not supported."
160     IdInfo.DataConWorkId datacon -> do
161       if DataCon.isTupleCon datacon && (null $ DataCon.dataConAllTyVars datacon)
162         then do
163           -- Empty tuple construction
164           return ([], Tuple [])
165         else do
166           lit <- dataConToLiteral datacon
167           let ty = CoreUtils.exprType var
168           sig_id <- genSignalId SigInternal ty
169           -- Add a name hint to the signal
170           addNameHint (Name.getOccString id) sig_id
171           addDef (UncondDef (Right $ Literal lit) sig_id)
172           return ([], Single sig_id)
173     otherwise ->
174       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
175
176 flattenExpr binds app@(App _ _) = do
177   -- Is this a data constructor application?
178   case CoreUtils.exprIsConApp_maybe app of
179     -- Is this a tuple construction?
180     Just (dc, args) -> if DataCon.isTupleCon dc 
181       then
182         flattenBuildTupleExpr binds (dataConAppArgs dc args)
183       else
184         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
185     otherwise ->
186       -- Normal function application
187       let ((Var f), args) = collectArgs app in
188       let fname = Name.getOccString f in
189       if fname == "fst" || fname == "snd" then do
190         (args', Tuple [a, b]) <- flattenExpr binds (last args)
191         return (args', if fname == "fst" then a else b)
192       else if fname == "patError" then do
193         -- This is essentially don't care, since the program will error out
194         -- here. We'll just define undriven signals here.
195         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
196         args <- mapM genSignals argtys
197         res <- genSignals resty
198         mapM (Traversable.mapM (addNameHint "NC")) args
199         Traversable.mapM (addNameHint "NC") res
200         return (args, res)
201       else if fname == "==" then do
202         -- Flatten the last two arguments (this skips the type arguments)
203         ([], a) <- flattenExpr binds (last $ init args)
204         ([], b) <- flattenExpr binds (last args)
205         res <- mkEqComparisons a b
206         return ([], res)
207       else
208         flattenApplicationExpr binds (CoreUtils.exprType app) f args
209   where
210     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
211     mkEqComparisons a b = do
212       let zipped = zipValueMaps a b
213       Traversable.mapM mkEqComparison zipped
214
215     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
216     mkEqComparison (a, b) = do
217       -- Generate a signal to hold our result
218       res <- genSignalId SigInternal TysWiredIn.boolTy
219       -- Add a name hint to the signal
220       addNameHint ("s" ++ show a ++ "_eq_s" ++ show b) res
221       addDef (UncondDef (Right $ Eq a b) res)
222       return res
223
224     flattenBuildTupleExpr binds args = do
225       -- Flatten each of our args
226       flat_args <- (mapM (flattenExpr binds) args)
227       -- Check and split each of the arguments
228       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
229       let res = Tuple arg_ress
230       return ([], res)
231
232 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
233   (b_args, b_res) <- flattenExpr binds bexpr
234   if not (null b_args)
235     then
236       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
237     else do
238       let binds' = (b, Left b_res) : binds
239       -- Add name hints to the generated signals
240       let binder_name = Name.getOccString b
241       Traversable.mapM (addNameHint binder_name) b_res
242       flattenExpr binds' expr
243
244 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
245
246 flattenExpr binds expr@(Case scrut b _ alts) = do
247   -- TODO: Special casing for higher order functions
248   -- Flatten the scrutinee
249   (_, res) <- flattenExpr binds scrut
250   case alts of
251     -- TODO include b in the binds list
252     [alt] -> flattenSingleAltCaseExpr binds res b alt
253     -- Reverse the alternatives, so the __DEFAULT alternative ends up last
254     otherwise -> flattenMultipleAltCaseExpr binds res b (reverse alts)
255   where
256     flattenSingleAltCaseExpr ::
257       BindMap
258                                 -- A list of bindings in effect
259       -> SignalMap              -- The scrutinee
260       -> CoreBndr               -- The binder to bind the scrutinee to
261       -> CoreAlt                -- The single alternative
262       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
263
264     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
265       if DataCon.isTupleCon datacon
266         then do
267           -- Unpack the scrutinee (which must be a variable bound to a tuple) in
268           -- the existing bindings list and get the portname map for each of
269           -- it's elements.
270           let Tuple tuple_sigs = scrut
271           -- Add name hints to the returned signals
272           let binder_name = Name.getOccString b
273           Monad.zipWithM (\name  sigs -> Traversable.mapM (addNameHint $ Name.getOccString name) sigs) bind_vars tuple_sigs
274           -- Merge our existing binds with the new binds.
275           let binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
276           -- Expand the expression with the new binds list
277           flattenExpr binds' expr
278         else
279           if null bind_vars
280             then
281               -- DataAlts without arguments don't need processing
282               -- (flattenMultipleAltCaseExpr will have done this already).
283               flattenExpr binds expr
284             else
285               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
286
287     flattenSingleAltCaseExpr binds _ _ alt@(DEFAULT, [], expr) =
288       flattenExpr binds expr
289       
290     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
291
292     flattenMultipleAltCaseExpr ::
293       BindMap
294                                 -- A list of bindings in effect
295       -> SignalMap              -- The scrutinee
296       -> CoreBndr               -- The binder to bind the scrutinee to
297       -> [CoreAlt]              -- The alternatives
298       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
299
300     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
301       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
302       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
303       case a of
304         (DataAlt datacon, bind_vars, expr) -> do
305           lit <- dataConToLiteral datacon
306           -- The scrutinee must be a single signal
307           let Single sig = scrut
308           -- Create a signal that contains a boolean
309           boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
310           addNameHint ("s" ++ show sig ++ "_eq_" ++ lit) boolsigid
311           let expr = EqLit sig lit
312           addDef (UncondDef (Right expr) boolsigid)
313           -- Create conditional assignments of either args/res or
314           -- args'/res based on boolsigid, and return the result.
315           -- TODO: It seems this adds the name hint twice?
316           our_args <- Monad.zipWithM (mkConditionals boolsigid) args args'
317           our_res  <- mkConditionals boolsigid res res'
318           return (our_args, our_res)
319         otherwise ->
320           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
321       where
322         -- Select either the first or second signal map depending on the value
323         -- of the first argument (True == first map, False == second map)
324         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
325         mkConditionals boolsigid true false = do
326           let zipped = zipValueMaps true false
327           Traversable.mapM (mkConditional boolsigid) zipped
328
329         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
330         mkConditional boolsigid (true, false) = do
331           -- Create a new signal (true and false should be identically typed,
332           -- so it doesn't matter which one we copy).
333           res <- copySignal true
334           addDef (CondDef boolsigid true false res)
335           return res
336
337     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
338       flattenSingleAltCaseExpr binds scrut b a
339
340 flattenExpr _ expr = do
341   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
342
343 -- | Flatten a normal application expression
344 flattenApplicationExpr binds ty f args = do
345   -- Find the function to call
346   let func = appToHsFunction ty f args
347   -- Flatten each of our args
348   flat_args <- (mapM (flattenExpr binds) args)
349   -- Check and split each of the arguments
350   let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
351   -- Generate signals for our result
352   res <- genSignals ty
353   -- Add name hints to the generated signals
354   let resname = Name.getOccString f ++ "_res"
355   Traversable.mapM (addNameHint resname) res
356   -- Create the function application
357   let app = FApp {
358     appFunc = func,
359     appArgs = arg_ress,
360     appRes  = res
361   }
362   addDef app
363   return ([], res)
364 -- | Check a flattened expression to see if it is valid to use as a
365 --   function argument. The first argument is the original expression for
366 --   use in the error message.
367 checkArg arg flat =
368   let (args, res) = flat in
369   if not (null args)
370     then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
371     else flat 
372
373 -- | Translates a dataconstructor without arguments to the corresponding
374 --   literal.
375 dataConToLiteral :: DataCon.DataCon -> FlattenState String
376 dataConToLiteral datacon = do
377   let tycon = DataCon.dataConTyCon datacon
378   let tyname = TyCon.tyConName tycon
379   case Name.getOccString tyname of
380     -- TODO: Do something more robust than string matching
381     "Bit"      -> do
382       let dcname = DataCon.dataConName datacon
383       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"
384       return lit
385     "Bool" -> do
386       let dcname = DataCon.dataConName datacon
387       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
388       return lit
389     otherwise ->
390       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
391
392 appToHsFunction ::
393   Type.Type       -- ^ The return type
394   -> Var.Var      -- ^ The function to call
395   -> [CoreExpr]   -- ^ The function arguments
396   -> HsFunction   -- ^ The needed HsFunction
397
398 appToHsFunction ty f args =
399   HsFunction hsname hsargs hsres
400   where
401     hsname = Name.getOccString f
402     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
403     hsres  = useAsPort (mkHsValueMap ty)
404
405 -- | Filters non-state signals and returns the state number and signal id for
406 --   state values.
407 filterState ::
408   SignalId                       -- | The signal id to look at
409   -> HsValueUse                  -- | How is this signal used?
410   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
411                                  --   signal was used as state
412
413 filterState id (State num) = 
414   Just (num, id)
415 filterState _ _ = Nothing
416
417 -- | Returns a list of the state number and signal id of all used-as-state
418 --   signals in the given maps.
419 stateList ::
420   HsUseMap
421   -> (SignalMap)
422   -> [(StateId, SignalId)]
423
424 stateList uses signals =
425     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
426   
427 -- vim: set ts=8 sw=2 sts=2 expandtab: