Remove type parameterisation of SignalMap.
[matthijs/master-project/cλash.git] / Flatten.hs
1 module Flatten where
2 import CoreSyn
3 import Control.Monad
4 import qualified Var
5 import qualified Type
6 import qualified Name
7 import qualified Maybe
8 import qualified Control.Arrow as Arrow
9 import qualified DataCon
10 import qualified CoreUtils
11 import qualified Data.Traversable as Traversable
12 import qualified Data.Foldable as Foldable
13 import Control.Applicative
14 import Outputable ( showSDoc, ppr )
15 import qualified Control.Monad.State as State
16
17 import HsValueMap
18 import TranslatorTypes
19 import FlattenTypes
20
21 -- Extract the arguments from a data constructor application (that is, the
22 -- normal args, leaving out the type args).
23 dataConAppArgs :: DataCon.DataCon -> [CoreExpr] -> [CoreExpr]
24 dataConAppArgs dc args =
25     drop tycount args
26   where
27     tycount = length $ DataCon.dataConAllTyVars dc
28
29 genSignals ::
30   Type.Type
31   -> FlattenState SignalMap
32
33 genSignals ty =
34   -- First generate a map with the right structure containing the types, and
35   -- generate signals for each of them.
36   Traversable.mapM (\ty -> genSignalId SigInternal ty) (mkHsValueMap ty)
37
38 -- | Marks a signal as the given SigUse, if its id is in the list of id's
39 --   given.
40 markSignals :: SigUse -> [SignalId] -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
41 markSignals use ids (id, info) =
42   (id, info')
43   where
44     info' = if id `elem` ids then info { sigUse = use} else info
45
46 markSignal :: SigUse -> SignalId -> (SignalId, SignalInfo) -> (SignalId, SignalInfo)
47 markSignal use id = markSignals use [id]
48
49 -- | Flatten a haskell function
50 flattenFunction ::
51   HsFunction                      -- ^ The function to flatten
52   -> CoreBind                     -- ^ The function value
53   -> FlatFunction                 -- ^ The resulting flat function
54
55 flattenFunction _ (Rec _) = error "Recursive binders not supported"
56 flattenFunction hsfunc bind@(NonRec var expr) =
57   FlatFunction args res apps conds sigs''''
58   where
59     init_state        = ([], [], [], 0)
60     (fres, end_state) = State.runState (flattenExpr [] expr) init_state
61     (apps, conds, sigs, _)  = end_state
62     (args, res)       = fres
63     arg_ports         = concat (map Foldable.toList args)
64     res_ports         = Foldable.toList res
65     -- Mark args and result signals as input and output ports resp.
66     sigs'             = fmap (markSignals SigPortIn arg_ports) sigs
67     sigs''            = fmap (markSignals SigPortOut res_ports) sigs'
68     -- Mark args and result states as old and new state resp.
69     args_states       = concat $ zipWith stateList (hsFuncArgs hsfunc) args
70     sigs'''           = foldl (\s (num, id) -> map (markSignal (SigStateOld num) id) s) sigs'' args_states
71     res_states        = stateList (hsFuncRes hsfunc) res
72     sigs''''          = foldl (\s (num, id) -> map (markSignal (SigStateNew num) id) s) sigs''' res_states
73
74 flattenExpr ::
75   BindMap
76   -> CoreExpr
77   -> FlattenState ([SignalMap], SignalMap)
78
79 flattenExpr binds lam@(Lam b expr) = do
80   -- Find the type of the binder
81   let (arg_ty, _) = Type.splitFunTy (CoreUtils.exprType lam)
82   -- Create signal names for the binder
83   defs <- genSignals arg_ty
84   let binds' = (b, Left defs):binds
85   (args, res) <- flattenExpr binds' expr
86   return (defs : args, res)
87
88 flattenExpr binds (Var id) =
89   case bind of
90     Left sig_use -> return ([], sig_use)
91     Right _ -> error "Higher order functions not supported."
92   where
93     bind = Maybe.fromMaybe
94       (error $ "Argument " ++ Name.getOccString id ++ "is unknown")
95       (lookup id binds)
96
97 flattenExpr binds app@(App _ _) = do
98   -- Is this a data constructor application?
99   case CoreUtils.exprIsConApp_maybe app of
100     -- Is this a tuple construction?
101     Just (dc, args) -> if DataCon.isTupleCon dc 
102       then
103         flattenBuildTupleExpr binds (dataConAppArgs dc args)
104       else
105         error $ "Data constructors other than tuples not supported: " ++ (showSDoc $ ppr app)
106     otherwise ->
107       -- Normal function application
108       let ((Var f), args) = collectArgs app in
109       flattenApplicationExpr binds (CoreUtils.exprType app) f args
110   where
111     flattenBuildTupleExpr binds args = do
112       -- Flatten each of our args
113       flat_args <- (State.mapM (flattenExpr binds) args)
114       -- Check and split each of the arguments
115       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
116       let res = Tuple arg_ress
117       return ([], res)
118
119     -- | Flatten a normal application expression
120     flattenApplicationExpr binds ty f args = do
121       -- Find the function to call
122       let func = appToHsFunction ty f args
123       -- Flatten each of our args
124       flat_args <- (State.mapM (flattenExpr binds) args)
125       -- Check and split each of the arguments
126       let (_, arg_ress) = unzip (zipWith checkArg args flat_args)
127       -- Generate signals for our result
128       res <- genSignals ty
129       -- Create the function application
130       let app = FApp {
131         appFunc = func,
132         appArgs = arg_ress,
133         appRes  = res
134       }
135       addApp app
136       return ([], res)
137     -- | Check a flattened expression to see if it is valid to use as a
138     --   function argument. The first argument is the original expression for
139     --   use in the error message.
140     checkArg arg flat =
141       let (args, res) = flat in
142       if not (null args)
143         then error $ "Passing lambda expression or function as a function argument not supported: " ++ (showSDoc $ ppr arg)
144         else flat 
145
146 flattenExpr binds l@(Let (NonRec b bexpr) expr) = do
147   (b_args, b_res) <- flattenExpr binds bexpr
148   if not (null b_args)
149     then
150       error $ "Higher order functions not supported in let expression: " ++ (showSDoc $ ppr l)
151     else
152       let binds' = (b, Left b_res) : binds in
153       flattenExpr binds' expr
154
155 flattenExpr binds l@(Let (Rec _) _) = error $ "Recursive let definitions not supported: " ++ (showSDoc $ ppr l)
156
157 flattenExpr binds expr@(Case (Var v) b _ alts) =
158   case alts of
159     [alt] -> flattenSingleAltCaseExpr binds v b alt
160     otherwise -> error $ "Multiple alternative case expression not supported: " ++ (showSDoc $ ppr expr)
161   where
162     flattenSingleAltCaseExpr ::
163       BindMap
164                                 -- A list of bindings in effect
165       -> Var.Var                -- The scrutinee
166       -> CoreBndr               -- The binder to bind the scrutinee to
167       -> CoreAlt                -- The single alternative
168       -> FlattenState ( [SignalMap], SignalMap)
169                                            -- See expandExpr
170     flattenSingleAltCaseExpr binds v b alt@(DataAlt datacon, bind_vars, expr) =
171       if not (DataCon.isTupleCon datacon) 
172         then
173           error $ "Dataconstructors other than tuple constructors not supported in case pattern of alternative: " ++ (showSDoc $ ppr alt)
174         else
175           let
176             -- Lookup the scrutinee (which must be a variable bound to a tuple) in
177             -- the existing bindings list and get the portname map for each of
178             -- it's elements.
179             Left (Tuple tuple_sigs) = Maybe.fromMaybe 
180               (error $ "Case expression uses unknown scrutinee " ++ Name.getOccString v)
181               (lookup v binds)
182             -- TODO include b in the binds list
183             -- Merge our existing binds with the new binds.
184             binds' = (zip bind_vars (map Left tuple_sigs)) ++ binds 
185           in
186             -- Expand the expression with the new binds list
187             flattenExpr binds' expr
188     flattenSingleAltCaseExpr _ _ _ alt = error $ "Case patterns other than data constructors not supported in case alternative: " ++ (showSDoc $ ppr alt)
189
190
191       
192 flattenExpr _ _ = do
193   return ([], Tuple [])
194
195 appToHsFunction ::
196   Type.Type       -- ^ The return type
197   -> Var.Var      -- ^ The function to call
198   -> [CoreExpr]   -- ^ The function arguments
199   -> HsFunction   -- ^ The needed HsFunction
200
201 appToHsFunction ty f args =
202   HsFunction hsname hsargs hsres
203   where
204     hsname = Name.getOccString f
205     hsargs = map (useAsPort . mkHsValueMap . CoreUtils.exprType) args
206     hsres  = useAsPort (mkHsValueMap ty)
207
208 -- | Filters non-state signals and returns the state number and signal id for
209 --   state values.
210 filterState ::
211   SignalId                       -- | The signal id to look at
212   -> HsValueUse                  -- | How is this signal used?
213   -> Maybe (Int, SignalId )      -- | The state num and signal id, if this
214                                  --   signal was used as state
215
216 filterState id (State num) = 
217   Just (num, id)
218 filterState _ _ = Nothing
219
220 -- | Returns a list of the state number and signal id of all used-as-state
221 --   signals in the given maps.
222 stateList ::
223   HsUseMap
224   -> (SignalMap)
225   -> [(Int, SignalId)]
226
227 stateList uses signals =
228     Maybe.catMaybes $ Foldable.toList $ zipValueMapsWith filterState signals uses
229   
230 -- | Returns pairs of signals that should be mapped to state in this function.
231 getOwnStates ::
232   HsFunction                      -- | The function to look at
233   -> FlatFunction                 -- | The function to look at
234   -> [(Int, SignalInfo, SignalInfo)]   
235         -- | The state signals. The first is the state number, the second the
236         --   signal to assign the current state to, the last is the signal
237         --   that holds the new state.
238
239 getOwnStates hsfunc flatfunc =
240   [(old_num, old_info, new_info) 
241     | (old_num, old_info) <- args_states
242     , (new_num, new_info) <- res_states
243     , old_num == new_num]
244   where
245     sigs = flat_sigs flatfunc
246     -- Translate args and res to lists of (statenum, sigid)
247     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
248     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
249     -- Replace the second tuple element with the corresponding SignalInfo
250     args_states = map (Arrow.second $ signalInfo sigs) args
251     res_states = map (Arrow.second $ signalInfo sigs) res
252
253     
254 -- vim: set ts=8 sw=2 sts=2 expandtab: