Make register_bank work, with a bunch of changes.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import Control.Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified TyCon
11 import qualified CoreUtils
12 import qualified TysWiredIn
13 import qualified IdInfo
14 import qualified Data.Traversable as Traversable
15 import qualified Data.Foldable as Foldable
16 import Control.Applicative
17 import Outputable ( showSDoc, ppr )
18 import qualified Control.Monad.State as State
19
20 import HsValueMap
21 import TranslatorTypes
22 import FlattenTypes
23
24 -- Extract the arguments from a data constructor application (that is, the
25 -- normal args, leaving out the type args).
26 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
27 dataConAppArgs dc args =
28     drop tycount args
29   where
30     tycount = length $ DataCon.dataConAllTyVars dc
31
32 genSignals ::
33   Type.Type
34   -> FlattenState SignalMap
35
36 genSignals ty =
37   -- First generate a map with the right structure containing the types, and
38   -- generate signals for each of them.
39   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
40
41 -- | Marks a signal as the given SigUse, if its id is in the list of id's
42 --   given.
43 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
44 markSignals use ids (id, info) =
45   (id, info')
46   where
47     info' = if id `elem` ids then info { sigUse = use} else info
48
49 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
50 markSignal use id = markSignals use [id]
51
52 -- | Flatten a haskell function
53 flattenFunction ::
54   HsFunction                      -- ^ The function to flatten
55   -> CoreBind                     -- ^ The function value
56   -> FlatFunction                 -- ^ The resulting flat function
57
58 flattenFunction _ (Rec _) = error "Recursive binders not supported"
59 flattenFunction hsfunc bind@(NonRec var expr) =
60   FlatFunction args res defs sigs
61   where
62     init_state        = ([], [], 0)
63     (fres, end_state) = State.runState (flattenTopExpr hsfunc expr) init_state
64     (defs, sigs, _)   = end_state
65     (args, res)       = fres
66
67 flattenTopExpr ::
68   HsFunction
69   -> CoreExpr
70   -> FlattenState ([SignalMap], SignalMap)
71
72 flattenTopExpr hsfunc expr = do
73   -- Flatten the expression
74   (args, res) <- flattenExpr [] expr
75   
76   -- Join the signal ids and uses together
77   let zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
78   let zipped_res = zipValueMaps res (hsFuncRes hsfunc)
79   -- Set the signal uses for each argument / result, possibly updating
80   -- argument or result signals.
81   args' <- mapM (Traversable.mapM $ hsUseToSigUse args_use) zipped_args
82   res' <- Traversable.mapM (hsUseToSigUse res_use) zipped_res
83   return (args', res')
84   where
85     args_use Port = SigPortIn
86     args_use (State n) = SigStateOld n
87     res_use Port = SigPortOut
88     res_use (State n) = SigStateNew n
89
90
91 hsUseToSigUse :: 
92   (HsValueUse -> SigUse)      -- ^ A function to actually map the use value
93   -> (SignalId, HsValueUse)   -- ^ The signal to look at and its use
94   -> FlattenState SignalId    -- ^ The resulting signal. This is probably the
95                               --   same as the input, but it could be different.
96 hsUseToSigUse f (id, use) = do
97   info <- getSignalInfo id
98   id' <- case sigUse info of 
99     -- Internal signals can be marked as different uses freely.
100     SigInternal -> do
101       return id
102     -- Signals that already have another use, must be duplicated before
103     -- marking. This prevents signals mapping to the same input or output
104     -- port or state variables and ports overlapping, etc.
105     otherwise -> do
106       duplicateSignal id
107   setSignalInfo id' (info { sigUse = f use})
108   return id'
109
110 -- | Creates a new internal signal with the same type as the given signal
111 copySignal :: SignalId -> FlattenState SignalId
112 copySignal id = do
113   -- Find the type of the original signal
114   info <- getSignalInfo id
115   let ty = sigTy info
116   -- Generate a new signal (which is SigInternal for now, that will be
117   -- sorted out later on).
118   genSignalId SigInternal ty
119
120 -- | Duplicate the given signal, assigning its value to the new signal.
121 --   Returns the new signal id.
122 duplicateSignal :: SignalId -> FlattenState SignalId
123 duplicateSignal id = do
124   -- Create a new signal
125   id' <- copySignal id
126   -- Assign the old signal to the new signal
127   addDef $ UncondDef (Left id) id'
128   -- Replace the signal with the new signal
129   return id'
130         
131 flattenExpr ::
132   BindMap
133   -> CoreExpr
134   -> FlattenState ([SignalMap], SignalMap)
135
136 flattenExpr binds lam@(Lam b expr) = do
137   -- Find the type of the binder
138   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
139   -- Create signal names for the binder
140   defs <- genSignals arg_ty
141   let binds' = (b, Left defs):binds
142   (args, res) <- flattenExpr binds' expr
143   return (defs : args, res)
144
145 flattenExpr binds var@(Var id) =
146   case Var.globalIdVarDetails id of
147     IdInfo.NotGlobalId ->
148       let 
149         bind = Maybe.fromMaybe
150           (error $ "Local value " ++ Name.getOccString id ++ " is unknown")
151           (lookup id binds) 
152       in
153         case bind of
154           Left sig_use -> return ([], sig_use)
155           Right _ -> error "Higher order functions not supported."
156     IdInfo.DataConWorkId datacon -> do
157       lit <- dataConToLiteral datacon
158       let ty = CoreUtils.exprType var
159       id <- genSignalId SigInternal ty
160       addDef (UncondDef (Right $ Literal lit) id)
161       return ([], Single id)
162     otherwise ->
163       error $ "Ids other than local vars and dataconstructors not supported: " ++ (showSDoc $ ppr id)
164
165 flattenExpr binds app@(App _ _) = do
166   -- Is this a data constructor application?
167   case CoreUtils.exprIsConApp_maybe app of
168     -- Is this a tuple construction?
169     Just (dc, args) -> if DataCon.isTupleCon dc 
170       then
171         flattenBuildTupleExpr binds (dataConAppArgs dc args)
172       else
173         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
174     otherwise ->
175       -- Normal function application
176       let ((Var f), args) = collectArgs app in
177       let fname = Name.getOccString f in
178       if fname == "fst" || fname == "snd" then do
179         (args', Tuple [a, b]) <- flattenExpr binds (last args)
180         return (args', if fname == "fst" then a else b)
181       else if fname == "patError" then do
182         -- This is essentially don't care, since the program will error out
183         -- here. We'll just define undriven signals here.
184         let (argtys, resty) = Type.splitFunTys $ CoreUtils.exprType app
185         args <- mapM genSignals argtys
186         res <- genSignals resty
187         return (args, res)
188       else if fname == "==" then do
189         -- Flatten the last two arguments (this skips the type arguments)
190         ([], a) <- flattenExpr binds (last $ init args)
191         ([], b) <- flattenExpr binds (last args)
192         res <- mkEqComparisons a b
193         return ([], res)
194       else
195         flattenApplicationExpr binds (CoreUtils.exprType app) f args
196   where
197     mkEqComparisons :: SignalMap -> SignalMap -> FlattenState SignalMap
198     mkEqComparisons a b = do
199       let zipped = zipValueMaps a b
200       Traversable.mapM mkEqComparison zipped
201
202     mkEqComparison :: (SignalId, SignalId) -> FlattenState SignalId
203     mkEqComparison (a, b) = do
204       -- Generate a signal to hold our result
205       res <- genSignalId SigInternal TysWiredIn.boolTy
206       addDef (UncondDef (Right $ Eq a b) res)
207       return res
208
209     flattenBuildTupleExpr binds args = do
210       -- Flatten each of our args
211       flat_args <- (State.mapM (flattenExpr binds) args)
212       -- Check and split each of the arguments
213       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
214       let res = Tuple arg_ress
215       return ([], res)
216
217     -- | Flatten a normal application expression
218     flattenApplicationExpr binds ty f args = do
219       -- Find the function to call
220       let func = appToHsFunction ty f args
221       -- Flatten each of our args
222       flat_args <- (State.mapM (flattenExpr binds) args)
223       -- Check and split each of the arguments
224       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
225       -- Generate signals for our result
226       res <- genSignals ty
227       -- Create the function application
228       let app = FApp {
229         appFunc = func,
230         appArgs = arg_ress,
231         appRes  = res
232       }
233       addDef app
234       return ([], res)
235     -- | Check a flattened expression to see if it is valid to use as a
236     --   function argument. The first argument is the original expression for
237     --   use in the error message.
238     checkArg arg flat =
239       let (args, res) = flat in
240       if not (null args)
241         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
242         else flat 
243
244 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
245   (b_args, b_res) <- flattenExpr binds bexpr
246   if not (null b_args)
247     then
248       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
249     else
250       let binds' = (b, Left b_res) : binds in
251       flattenExpr binds' expr
252
253 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
254
255 flattenExpr binds expr@(Case scrut b _ alts) = do
256   -- TODO: Special casing for higher order functions
257   -- Flatten the scrutinee
258   (_, res) <- flattenExpr binds scrut
259   case alts of
260     [alt] -> flattenSingleAltCaseExpr binds res b alt
261     otherwise -> flattenMultipleAltCaseExpr binds res b alts
262   where
263     flattenSingleAltCaseExpr ::
264       BindMap
265                                 -- A list of bindings in effect
266       -> SignalMap              -- The scrutinee
267       -> CoreBndr               -- The binder to bind the scrutinee to
268       -> CoreAlt                -- The single alternative
269       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
270
271     flattenSingleAltCaseExpr binds scrut b alt@(DataAlt datacon, bind_vars, expr) =
272       if DataCon.isTupleCon datacon
273         then
274           let
275             -- Unpack the scrutinee (which must be a variable bound to a tuple) in
276             -- the existing bindings list and get the portname map for each of
277             -- it's elements.
278             Tuple tuple_sigs = scrut
279             -- TODO include b in the binds list
280             -- Merge our existing binds with the new binds.
281             binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
282           in
283             -- Expand the expression with the new binds list
284             flattenExpr binds' expr
285         else
286           if null bind_vars
287             then
288               -- DataAlts without arguments don't need processing
289               -- (flattenMultipleAltCaseExpr will have done this already).
290               flattenExpr binds expr
291             else
292               error $ "Dataconstructors other than tuple constructors cannot have binder arguments in case pattern of alternative: " ++ (showSDoc $ ppr alt)
293     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
294
295     flattenMultipleAltCaseExpr ::
296       BindMap
297                                 -- A list of bindings in effect
298       -> SignalMap              -- The scrutinee
299       -> CoreBndr               -- The binder to bind the scrutinee to
300       -> [CoreAlt]              -- The alternatives
301       -> FlattenState ( [SignalMap], SignalMap) -- See expandExpr
302
303     flattenMultipleAltCaseExpr binds scrut b (a:a':alts) = do
304       (args, res) <- flattenSingleAltCaseExpr binds scrut b a
305       (args', res') <- flattenMultipleAltCaseExpr binds scrut b (a':alts)
306       case a of
307         (DataAlt datacon, bind_vars, expr) -> do
308           lit <- dataConToLiteral datacon
309           -- The scrutinee must be a single signal
310           let Single sig = scrut
311           -- Create a signal that contains a boolean
312           boolsigid <- genSignalId SigInternal TysWiredIn.boolTy
313           let expr = EqLit sig lit
314           addDef (UncondDef (Right expr) boolsigid)
315           -- Create conditional assignments of either args/res or
316           -- args'/res based on boolsigid, and return the result.
317           our_args <- zipWithM (mkConditionals boolsigid) args args'
318           our_res  <- mkConditionals boolsigid res res'
319           return (our_args, our_res)
320         otherwise ->
321           error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr a)
322       where
323         -- Select either the first or second signal map depending on the value
324         -- of the first argument (True == first map, False == second map)
325         mkConditionals :: SignalId -> SignalMap -> SignalMap -> FlattenState SignalMap
326         mkConditionals boolsigid true false = do
327           let zipped = zipValueMaps true false
328           Traversable.mapM (mkConditional boolsigid) zipped
329
330         mkConditional :: SignalId -> (SignalId, SignalId) -> FlattenState SignalId
331         mkConditional boolsigid (true, false) = do
332           -- Create a new signal (true and false should be identically typed,
333           -- so it doesn't matter which one we copy).
334           res <- copySignal true
335           addDef (CondDef boolsigid true false res)
336           return res
337
338     flattenMultipleAltCaseExpr binds scrut b (a:alts) =
339       flattenSingleAltCaseExpr binds scrut b a
340
341 flattenExpr _ expr = do
342   error $ "Unsupported expression: " ++ (showSDoc $ ppr expr)
343
344 -- | Translates a dataconstructor without arguments to the corresponding
345 --   literal.
346 dataConToLiteral :: DataCon.DataCon -> FlattenState String
347 dataConToLiteral datacon = do
348   let tycon = DataCon.dataConTyCon datacon
349   let tyname = TyCon.tyConName tycon
350   case Name.getOccString tyname of
351     -- TODO: Do something more robust than string matching
352     "Bit"      -> do
353       let dcname = DataCon.dataConName datacon
354       let lit = case Name.getOccString dcname of "High" -> "'1'"; "Low" -> "'0'"
355       return lit
356     "Bool" -> do
357       let dcname = DataCon.dataConName datacon
358       let lit = case Name.getOccString dcname of "True" -> "true"; "False" -> "false"
359       return lit
360     otherwise ->
361       error $ "Literals of type " ++ (Name.getOccString tyname) ++ " not supported."
362
363 appToHsFunction ::
364   Type.Type       -- ^ The return type
365   -> Var.Var      -- ^ The function to call
366   -> [CoreExpr]   -- ^ The function arguments
367   -> HsFunction   -- ^ The needed HsFunction
368
369 appToHsFunction ty f args =
370   HsFunction hsname hsargs hsres
371   where
372     hsname = Name.getOccString f
373     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
374     hsres  = useAsPort (mkHsValueMap ty)
375
376 -- | Filters non-state signals and returns the state number and signal id for
377 --   state values.
378 filterState ::
379   SignalId                       -- | The signal id to look at
380   -> HsValueUse                  -- | How is this signal used?
381   -> Maybe (StateId, SignalId )  -- | The state num and signal id, if this
382                                  --   signal was used as state
383
384 filterState id (State num) = 
385   Just (num, id)
386 filterState _ _ = Nothing
387
388 -- | Returns a list of the state number and signal id of all used-as-state
389 --   signals in the given maps.
390 stateList ::
391   HsUseMap
392   -> (SignalMap)
393   -> [(StateId, SignalId)]
394
395 stateList uses signals =
396     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
397   
398 -- | Returns pairs of signals that should be mapped to state in this function.
399 getOwnStates ::
400   HsFunction                      -- | The function to look at
401   -> FlatFunction                 -- | The function to look at
402   -> [(StateId, SignalInfo, SignalInfo)]   
403         -- | The state signals. The first is the state number, the second the
404         --   signal to assign the current state to, the last is the signal
405         --   that holds the new state.
406
407 getOwnStates hsfunc flatfunc =
408   [(old_num, old_info, new_info) 
409     | (old_num, old_info) <- args_states
410     , (new_num, new_info) <- res_states
411     , old_num == new_num]
412   where
413     sigs = flat_sigs flatfunc
414     -- Translate args and res to lists of (statenum, sigid)
415     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
416     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
417     -- Replace the second tuple element with the corresponding SignalInfo
418     args_states = map (Arrow.second $ signalInfo sigs) args
419     res_states = map (Arrow.second $ signalInfo sigs) res
420
421     
422 -- vim: set ts=8 sw=2 sts=2 expandtab: