Remove createArchitecture from the VHDLState Monad.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Control.Monad.State as State
14 import qualified Data.Foldable as Foldable
15 import Name
16 import qualified Data.Map as Map
17 import Data.Generics
18 import NameEnv ( lookupNameEnv )
19 import qualified HscTypes
20 import HscTypes ( cm_binds, cm_types )
21 import MonadUtils ( liftIO )
22 import Outputable ( showSDoc, ppr )
23 import GHC.Paths ( libdir )
24 import DynFlags ( defaultDynFlags )
25 import List ( find )
26 import qualified List
27 import qualified Monad
28
29 -- The following modules come from the ForSyDe project. They are really
30 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
31 -- ForSyDe to get access to these modules.
32 import qualified ForSyDe.Backend.VHDL.AST as AST
33 import qualified ForSyDe.Backend.VHDL.Ppr
34 import qualified ForSyDe.Backend.VHDL.FileIO
35 import qualified ForSyDe.Backend.Ppr
36 -- This is needed for rendering the pretty printed VHDL
37 import Text.PrettyPrint.HughesPJ (render)
38
39 import TranslatorTypes
40 import HsValueMap
41 import Pretty
42 import Flatten
43 import FlattenTypes
44 import VHDLTypes
45 import qualified VHDL
46
47 main = do
48   makeVHDL "Alu.hs" "register_bank" True
49
50 makeVHDL :: String -> String -> Bool -> IO ()
51 makeVHDL filename name stateful = do
52   -- Load the module
53   core <- loadModule filename
54   -- Translate to VHDL
55   vhdl <- moduleToVHDL core [(name, stateful)]
56   -- Write VHDL to file
57   let dir = "../vhdl/vhdl/" ++ name ++ "/"
58   mapM (writeVHDL dir) vhdl
59   return ()
60
61 -- | Show the core structure of the given binds in the given file.
62 listBind :: String -> String -> IO ()
63 listBind filename name = do
64   core <- loadModule filename
65   let binds = findBinds core [name]
66   putStr "\n"
67   putStr $ prettyShow binds
68   putStr "\n\n"
69   putStr $ showSDoc $ ppr binds
70   putStr "\n\n"
71
72 -- | Translate the binds with the given names from the given core module to
73 --   VHDL. The Bool in the tuple makes the function stateful (True) or
74 --   stateless (False).
75 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [AST.DesignFile]
76 moduleToVHDL core list = do
77   let (names, statefuls) = unzip list
78   --liftIO $ putStr $ prettyShow (cm_binds core)
79   let binds = findBinds core names
80   --putStr $ prettyShow binds
81   -- Turn bind into VHDL
82   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (VHDLSession core 0 Map.empty)
83   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr) vhdl
84   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
85   return vhdl
86
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL binds statefuls = do
90       -- Add the builtin functions
91       mapM addBuiltIn builtin_funcs
92       -- Create entities and architectures for them
93       Monad.zipWithM processBind statefuls binds
94       modFuncMap $ Map.map (\fdata -> fdata {flatFunc = fmap nameFlatFunction (flatFunc fdata)})
95       modFuncMap $ Map.mapWithKey (\hsfunc fdata -> fdata {funcEntity = VHDL.createEntity hsfunc fdata})
96       funcs <- getFuncMap
97       modFuncMap $ Map.mapWithKey (\hsfunc fdata -> fdata {funcArch = VHDL.createArchitecture funcs hsfunc fdata})
98       funcs <- getFuncs
99       return $ VHDL.getDesignFiles (map snd funcs)
100
101 -- | Write the given design file to a file inside the given dir
102 --   The first library unit in the designfile must be an entity, whose name
103 --   will be used as a filename.
104 writeVHDL :: String -> AST.DesignFile -> IO ()
105 writeVHDL dir vhdl = do
106   -- Create the dir if needed
107   exists <- Directory.doesDirectoryExist dir
108   Monad.unless exists $ Directory.createDirectory dir
109   -- Find the filename
110   let AST.DesignFile _ (u:us) = vhdl
111   let AST.LUEntity (AST.EntityDec id _) = u
112   let fname = dir ++ AST.fromVHDLId id ++ ".vhdl"
113   -- Write the file
114   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
115
116 -- | Loads the given file and turns it into a core module.
117 loadModule :: String -> IO HscTypes.CoreModule
118 loadModule filename =
119   defaultErrorHandler defaultDynFlags $ do
120     runGhc (Just libdir) $ do
121       dflags <- getSessionDynFlags
122       setSessionDynFlags dflags
123       --target <- guessTarget "adder.hs" Nothing
124       --liftIO (print (showSDoc (ppr (target))))
125       --liftIO $ printTarget target
126       --setTargets [target]
127       --load LoadAllTargets
128       --core <- GHC.compileToCoreSimplified "Adders.hs"
129       core <- GHC.compileToCoreSimplified filename
130       return core
131
132 -- | Extracts the named binds from the given module.
133 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
134 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
135
136 -- | Extract a named bind from the given list of binds
137 findBind :: [CoreBind] -> String -> Maybe CoreBind
138 findBind binds lookfor =
139   -- This ignores Recs and compares the name of the bind with lookfor,
140   -- disregarding any namespaces in OccName and extra attributes in Name and
141   -- Var.
142   find (\b -> case b of 
143     Rec l -> False
144     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
145   ) binds
146
147 -- | Processes the given bind as a top level bind.
148 processBind ::
149   Bool                       -- ^ Should this be stateful function?
150   -> CoreBind                -- ^ The bind to process
151   -> VHDLState ()
152
153 processBind _ (Rec _) = error "Recursive binders not supported"
154 processBind stateful bind@(NonRec var expr) = do
155   -- Create the function signature
156   let ty = CoreUtils.exprType expr
157   let hsfunc = mkHsFunction var ty stateful
158   flattenBind hsfunc bind
159
160 -- | Flattens the given bind into the given signature and adds it to the
161 --   session. Then (recursively) finds any functions it uses and does the same
162 --   with them.
163 flattenBind ::
164   HsFunction                         -- The signature to flatten into
165   -> CoreBind                        -- The bind to flatten
166   -> VHDLState ()
167
168 flattenBind _ (Rec _) = error "Recursive binders not supported"
169
170 flattenBind hsfunc bind@(NonRec var expr) = do
171   -- Add the function to the session
172   addFunc hsfunc
173   -- Flatten the function
174   let flatfunc = flattenFunction hsfunc bind
175   -- Propagate state variables
176   let flatfunc' = propagateState hsfunc flatfunc
177   -- Store the flat function in the session
178   setFlatFunc hsfunc flatfunc'
179   -- Flatten any functions used
180   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
181   State.mapM resolvFunc used_hsfuncs
182   return ()
183
184 -- | Decide which incoming state variables will become state in the
185 --   given function, and which will be propagate to other applied
186 --   functions.
187 propagateState ::
188   HsFunction
189   -> FlatFunction
190   -> FlatFunction
191
192 propagateState hsfunc flatfunc =
193     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
194   where
195     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
196     states' = zip olds news
197     -- Find all signals used by all sigdefs
198     uses = concatMap sigDefUses (flat_defs flatfunc)
199     -- Find all signals that are used more than once (is there a
200     -- prettier way to do this?)
201     multiple_uses = uses List.\\ (List.nub uses)
202     -- Find the states whose "old state" signal is used only once
203     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
204     -- See if these single use states can be propagated
205     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
206     substate_sigs = concat substate_sigss
207     -- Mark any propagated state signals as SigSubState
208     sigs' = map 
209       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
210       (flat_sigs flatfunc)
211
212 -- | Propagate the state into a single function application.
213 propagateState' ::
214   [(SignalId, SignalId)]
215                       -- ^ TODO
216   -> SigDef           -- ^ The SigDef to process.
217   -> ([SignalId], SigDef) 
218                       -- ^ Any signal ids that should become substates,
219                       --   and the resulting application.
220
221 propagateState' states def =
222     if (is_FApp def) then
223       (our_old ++ our_new, def {appFunc = hsfunc'})
224     else
225       ([], def)
226   where
227     hsfunc = appFunc def
228     args = appArgs def
229     res = appRes def
230     our_states = filter our_state states
231     -- A state signal belongs in this function if the old state is
232     -- passed in, and the new state returned
233     our_state (old, new) =
234       any (old `Foldable.elem`) args
235       && new `Foldable.elem` res
236     (our_old, our_new) = unzip our_states
237     -- Mark the result
238     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
239     res' = fmap (mark_state (zip our_new [0..])) zipped_res
240     -- Mark the args
241     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
242     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
243     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
244
245     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
246     mark_state states (id, use) =
247       case lookup id states of
248         Nothing -> use
249         Just state_id -> State state_id
250
251 -- | Returns pairs of signals that should be mapped to state in this function.
252 getStateSignals ::
253   HsFunction                      -- | The function to look at
254   -> FlatFunction                 -- | The function to look at
255   -> [(SignalId, SignalId)]   
256         -- | TODO The state signals. The first is the state number, the second the
257         --   signal to assign the current state to, the last is the signal
258         --   that holds the new state.
259
260 getStateSignals hsfunc flatfunc =
261   [(old_id, new_id) 
262     | (old_num, old_id) <- args
263     , (new_num, new_id) <- res
264     , old_num == new_num]
265   where
266     sigs = flat_sigs flatfunc
267     -- Translate args and res to lists of (statenum, sigid)
268     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
269     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
270     
271 -- | Find the given function, flatten it and add it to the session. Then
272 --   (recursively) do the same for any functions used.
273 resolvFunc ::
274   HsFunction        -- | The function to look for
275   -> VHDLState ()
276
277 resolvFunc hsfunc = do
278   -- See if the function is already known
279   func <- getFunc hsfunc
280   case func of
281     -- Already known, do nothing
282     Just _ -> do
283       return ()
284     -- New function, resolve it
285     Nothing -> do
286       -- Get the current module
287       core <- getModule
288       -- Find the named function
289       let bind = findBind (cm_binds core) name
290       case bind of
291         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
292         Just b  -> flattenBind hsfunc b
293   where
294     name = hsFuncName hsfunc
295
296 -- | Translate a top level function declaration to a HsFunction. i.e., which
297 --   interface will be provided by this function. This function essentially
298 --   defines the "calling convention" for hardware models.
299 mkHsFunction ::
300   Var.Var         -- ^ The function defined
301   -> Type         -- ^ The function type (including arguments!)
302   -> Bool         -- ^ Is this a stateful function?
303   -> HsFunction   -- ^ The resulting HsFunction
304
305 mkHsFunction f ty stateful=
306   HsFunction hsname hsargs hsres
307   where
308     hsname  = getOccString f
309     (arg_tys, res_ty) = Type.splitFunTys ty
310     (hsargs, hsres) = 
311       if stateful 
312       then
313         let
314           -- The last argument must be state
315           state_ty = last arg_tys
316           state    = useAsState (mkHsValueMap state_ty)
317           -- All but the last argument are inports
318           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
319           hsargs   = inports ++ [state]
320           hsres    = case splitTupleType res_ty of
321             -- Result type must be a two tuple (state, ports)
322             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
323               then
324                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
325               else
326                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
327             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
328         in
329           (hsargs, hsres)
330       else
331         -- Just use everything as a port
332         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
333
334 -- | Adds signal names to the given FlatFunction
335 nameFlatFunction ::
336   FlatFunction
337   -> FlatFunction
338
339 nameFlatFunction flatfunc =
340   -- Name the signals
341   let 
342     s = flat_sigs flatfunc
343     s' = map nameSignal s in
344   flatfunc { flat_sigs = s' }
345   where
346     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
347     nameSignal (id, info) =
348       let hints = nameHints info in
349       let parts = ("sig" : hints) ++ [show id] in
350       let name = concat $ List.intersperse "_" parts in
351       (id, info {sigName = Just name})
352
353 -- | Splits a tuple type into a list of element types, or Nothing if the type
354 --   is not a tuple type.
355 splitTupleType ::
356   Type              -- ^ The type to split
357   -> Maybe [Type]   -- ^ The tuples element types
358
359 splitTupleType ty =
360   case Type.splitTyConApp_maybe ty of
361     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
362       then
363         Just args
364       else
365         Nothing
366     Nothing -> Nothing
367
368 -- | A consise representation of a (set of) ports on a builtin function
369 type PortMap = HsValueMap (String, AST.TypeMark)
370 -- | A consise representation of a builtin function
371 data BuiltIn = BuiltIn String [PortMap] PortMap
372
373 -- | Map a port specification of a builtin function to a VHDL Signal to put in
374 --   a VHDLSignalMap
375 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
376 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
377
378 -- | Translate a concise representation of a builtin function to something
379 --   that can be put into FuncMap directly.
380 addBuiltIn :: BuiltIn -> VHDLState ()
381 addBuiltIn (BuiltIn name args res) = do
382     addFunc hsfunc
383     setEntity hsfunc entity
384   where
385     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
386     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing Nothing
387
388 builtin_funcs = 
389   [ 
390     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
391     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
392     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
393     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
394   ]
395
396 -- vim: set ts=8 sw=2 sts=2 expandtab: