Remove createEntity from the VHDLState monad.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Control.Monad.State as State
14 import qualified Data.Foldable as Foldable
15 import Name
16 import qualified Data.Map as Map
17 import Data.Generics
18 import NameEnv ( lookupNameEnv )
19 import qualified HscTypes
20 import HscTypes ( cm_binds, cm_types )
21 import MonadUtils ( liftIO )
22 import Outputable ( showSDoc, ppr )
23 import GHC.Paths ( libdir )
24 import DynFlags ( defaultDynFlags )
25 import List ( find )
26 import qualified List
27 import qualified Monad
28
29 -- The following modules come from the ForSyDe project. They are really
30 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
31 -- ForSyDe to get access to these modules.
32 import qualified ForSyDe.Backend.VHDL.AST as AST
33 import qualified ForSyDe.Backend.VHDL.Ppr
34 import qualified ForSyDe.Backend.VHDL.FileIO
35 import qualified ForSyDe.Backend.Ppr
36 -- This is needed for rendering the pretty printed VHDL
37 import Text.PrettyPrint.HughesPJ (render)
38
39 import TranslatorTypes
40 import HsValueMap
41 import Pretty
42 import Flatten
43 import FlattenTypes
44 import VHDLTypes
45 import qualified VHDL
46
47 main = do
48   makeVHDL "Alu.hs" "register_bank" True
49
50 makeVHDL :: String -> String -> Bool -> IO ()
51 makeVHDL filename name stateful = do
52   -- Load the module
53   core <- loadModule filename
54   -- Translate to VHDL
55   vhdl <- moduleToVHDL core [(name, stateful)]
56   -- Write VHDL to file
57   let dir = "../vhdl/vhdl/" ++ name ++ "/"
58   mapM (writeVHDL dir) vhdl
59   return ()
60
61 -- | Show the core structure of the given binds in the given file.
62 listBind :: String -> String -> IO ()
63 listBind filename name = do
64   core <- loadModule filename
65   let binds = findBinds core [name]
66   putStr "\n"
67   putStr $ prettyShow binds
68   putStr "\n\n"
69   putStr $ showSDoc $ ppr binds
70   putStr "\n\n"
71
72 -- | Translate the binds with the given names from the given core module to
73 --   VHDL. The Bool in the tuple makes the function stateful (True) or
74 --   stateless (False).
75 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [AST.DesignFile]
76 moduleToVHDL core list = do
77   let (names, statefuls) = unzip list
78   --liftIO $ putStr $ prettyShow (cm_binds core)
79   let binds = findBinds core names
80   --putStr $ prettyShow binds
81   -- Turn bind into VHDL
82   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (VHDLSession core 0 Map.empty)
83   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr) vhdl
84   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
85   return vhdl
86
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL binds statefuls = do
90       -- Add the builtin functions
91       mapM addBuiltIn builtin_funcs
92       -- Create entities and architectures for them
93       Monad.zipWithM processBind statefuls binds
94       modFuncs nameFlatFunction
95       modFuncMap $ Map.mapWithKey (\hsfunc fdata -> fdata {funcEntity = VHDL.createEntity hsfunc fdata})
96       modFuncs VHDL.createArchitecture
97       funcs <- getFuncs
98       return $ VHDL.getDesignFiles (map snd funcs)
99
100 -- | Write the given design file to a file inside the given dir
101 --   The first library unit in the designfile must be an entity, whose name
102 --   will be used as a filename.
103 writeVHDL :: String -> AST.DesignFile -> IO ()
104 writeVHDL dir vhdl = do
105   -- Create the dir if needed
106   exists <- Directory.doesDirectoryExist dir
107   Monad.unless exists $ Directory.createDirectory dir
108   -- Find the filename
109   let AST.DesignFile _ (u:us) = vhdl
110   let AST.LUEntity (AST.EntityDec id _) = u
111   let fname = dir ++ AST.fromVHDLId id ++ ".vhdl"
112   -- Write the file
113   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
114
115 -- | Loads the given file and turns it into a core module.
116 loadModule :: String -> IO HscTypes.CoreModule
117 loadModule filename =
118   defaultErrorHandler defaultDynFlags $ do
119     runGhc (Just libdir) $ do
120       dflags <- getSessionDynFlags
121       setSessionDynFlags dflags
122       --target <- guessTarget "adder.hs" Nothing
123       --liftIO (print (showSDoc (ppr (target))))
124       --liftIO $ printTarget target
125       --setTargets [target]
126       --load LoadAllTargets
127       --core <- GHC.compileToCoreSimplified "Adders.hs"
128       core <- GHC.compileToCoreSimplified filename
129       return core
130
131 -- | Extracts the named binds from the given module.
132 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
133 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
134
135 -- | Extract a named bind from the given list of binds
136 findBind :: [CoreBind] -> String -> Maybe CoreBind
137 findBind binds lookfor =
138   -- This ignores Recs and compares the name of the bind with lookfor,
139   -- disregarding any namespaces in OccName and extra attributes in Name and
140   -- Var.
141   find (\b -> case b of 
142     Rec l -> False
143     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
144   ) binds
145
146 -- | Processes the given bind as a top level bind.
147 processBind ::
148   Bool                       -- ^ Should this be stateful function?
149   -> CoreBind                -- ^ The bind to process
150   -> VHDLState ()
151
152 processBind _ (Rec _) = error "Recursive binders not supported"
153 processBind stateful bind@(NonRec var expr) = do
154   -- Create the function signature
155   let ty = CoreUtils.exprType expr
156   let hsfunc = mkHsFunction var ty stateful
157   flattenBind hsfunc bind
158
159 -- | Flattens the given bind into the given signature and adds it to the
160 --   session. Then (recursively) finds any functions it uses and does the same
161 --   with them.
162 flattenBind ::
163   HsFunction                         -- The signature to flatten into
164   -> CoreBind                        -- The bind to flatten
165   -> VHDLState ()
166
167 flattenBind _ (Rec _) = error "Recursive binders not supported"
168
169 flattenBind hsfunc bind@(NonRec var expr) = do
170   -- Add the function to the session
171   addFunc hsfunc
172   -- Flatten the function
173   let flatfunc = flattenFunction hsfunc bind
174   -- Propagate state variables
175   let flatfunc' = propagateState hsfunc flatfunc
176   -- Store the flat function in the session
177   setFlatFunc hsfunc flatfunc'
178   -- Flatten any functions used
179   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
180   State.mapM resolvFunc used_hsfuncs
181   return ()
182
183 -- | Decide which incoming state variables will become state in the
184 --   given function, and which will be propagate to other applied
185 --   functions.
186 propagateState ::
187   HsFunction
188   -> FlatFunction
189   -> FlatFunction
190
191 propagateState hsfunc flatfunc =
192     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
193   where
194     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
195     states' = zip olds news
196     -- Find all signals used by all sigdefs
197     uses = concatMap sigDefUses (flat_defs flatfunc)
198     -- Find all signals that are used more than once (is there a
199     -- prettier way to do this?)
200     multiple_uses = uses List.\\ (List.nub uses)
201     -- Find the states whose "old state" signal is used only once
202     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
203     -- See if these single use states can be propagated
204     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
205     substate_sigs = concat substate_sigss
206     -- Mark any propagated state signals as SigSubState
207     sigs' = map 
208       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
209       (flat_sigs flatfunc)
210
211 -- | Propagate the state into a single function application.
212 propagateState' ::
213   [(SignalId, SignalId)]
214                       -- ^ TODO
215   -> SigDef           -- ^ The SigDef to process.
216   -> ([SignalId], SigDef) 
217                       -- ^ Any signal ids that should become substates,
218                       --   and the resulting application.
219
220 propagateState' states def =
221     if (is_FApp def) then
222       (our_old ++ our_new, def {appFunc = hsfunc'})
223     else
224       ([], def)
225   where
226     hsfunc = appFunc def
227     args = appArgs def
228     res = appRes def
229     our_states = filter our_state states
230     -- A state signal belongs in this function if the old state is
231     -- passed in, and the new state returned
232     our_state (old, new) =
233       any (old `Foldable.elem`) args
234       && new `Foldable.elem` res
235     (our_old, our_new) = unzip our_states
236     -- Mark the result
237     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
238     res' = fmap (mark_state (zip our_new [0..])) zipped_res
239     -- Mark the args
240     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
241     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
242     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
243
244     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
245     mark_state states (id, use) =
246       case lookup id states of
247         Nothing -> use
248         Just state_id -> State state_id
249
250 -- | Returns pairs of signals that should be mapped to state in this function.
251 getStateSignals ::
252   HsFunction                      -- | The function to look at
253   -> FlatFunction                 -- | The function to look at
254   -> [(SignalId, SignalId)]   
255         -- | TODO The state signals. The first is the state number, the second the
256         --   signal to assign the current state to, the last is the signal
257         --   that holds the new state.
258
259 getStateSignals hsfunc flatfunc =
260   [(old_id, new_id) 
261     | (old_num, old_id) <- args
262     , (new_num, new_id) <- res
263     , old_num == new_num]
264   where
265     sigs = flat_sigs flatfunc
266     -- Translate args and res to lists of (statenum, sigid)
267     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
268     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
269     
270 -- | Find the given function, flatten it and add it to the session. Then
271 --   (recursively) do the same for any functions used.
272 resolvFunc ::
273   HsFunction        -- | The function to look for
274   -> VHDLState ()
275
276 resolvFunc hsfunc = do
277   -- See if the function is already known
278   func <- getFunc hsfunc
279   case func of
280     -- Already known, do nothing
281     Just _ -> do
282       return ()
283     -- New function, resolve it
284     Nothing -> do
285       -- Get the current module
286       core <- getModule
287       -- Find the named function
288       let bind = findBind (cm_binds core) name
289       case bind of
290         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
291         Just b  -> flattenBind hsfunc b
292   where
293     name = hsFuncName hsfunc
294
295 -- | Translate a top level function declaration to a HsFunction. i.e., which
296 --   interface will be provided by this function. This function essentially
297 --   defines the "calling convention" for hardware models.
298 mkHsFunction ::
299   Var.Var         -- ^ The function defined
300   -> Type         -- ^ The function type (including arguments!)
301   -> Bool         -- ^ Is this a stateful function?
302   -> HsFunction   -- ^ The resulting HsFunction
303
304 mkHsFunction f ty stateful=
305   HsFunction hsname hsargs hsres
306   where
307     hsname  = getOccString f
308     (arg_tys, res_ty) = Type.splitFunTys ty
309     (hsargs, hsres) = 
310       if stateful 
311       then
312         let
313           -- The last argument must be state
314           state_ty = last arg_tys
315           state    = useAsState (mkHsValueMap state_ty)
316           -- All but the last argument are inports
317           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
318           hsargs   = inports ++ [state]
319           hsres    = case splitTupleType res_ty of
320             -- Result type must be a two tuple (state, ports)
321             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
322               then
323                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
324               else
325                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
326             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
327         in
328           (hsargs, hsres)
329       else
330         -- Just use everything as a port
331         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
332
333 -- | Adds signal names to the given FlatFunction
334 nameFlatFunction ::
335   HsFunction
336   -> FuncData
337   -> VHDLState ()
338
339 nameFlatFunction hsfunc fdata =
340   let func = flatFunc fdata in
341   case func of
342     -- Skip (builtin) functions without a FlatFunction
343     Nothing -> do return ()
344     -- Name the signals in all other functions
345     Just flatfunc ->
346       let s = flat_sigs flatfunc in
347       let s' = map nameSignal s in
348       let flatfunc' = flatfunc { flat_sigs = s' } in
349       setFlatFunc hsfunc flatfunc'
350   where
351     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
352     nameSignal (id, info) =
353       let hints = nameHints info in
354       let parts = ("sig" : hints) ++ [show id] in
355       let name = concat $ List.intersperse "_" parts in
356       (id, info {sigName = Just name})
357
358 -- | Splits a tuple type into a list of element types, or Nothing if the type
359 --   is not a tuple type.
360 splitTupleType ::
361   Type              -- ^ The type to split
362   -> Maybe [Type]   -- ^ The tuples element types
363
364 splitTupleType ty =
365   case Type.splitTyConApp_maybe ty of
366     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
367       then
368         Just args
369       else
370         Nothing
371     Nothing -> Nothing
372
373 -- | A consise representation of a (set of) ports on a builtin function
374 type PortMap = HsValueMap (String, AST.TypeMark)
375 -- | A consise representation of a builtin function
376 data BuiltIn = BuiltIn String [PortMap] PortMap
377
378 -- | Map a port specification of a builtin function to a VHDL Signal to put in
379 --   a VHDLSignalMap
380 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
381 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
382
383 -- | Translate a concise representation of a builtin function to something
384 --   that can be put into FuncMap directly.
385 addBuiltIn :: BuiltIn -> VHDLState ()
386 addBuiltIn (BuiltIn name args res) = do
387     addFunc hsfunc
388     setEntity hsfunc entity
389   where
390     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
391     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing Nothing
392
393 builtin_funcs = 
394   [ 
395     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
396     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
397     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
398     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
399   ]
400
401 -- vim: set ts=8 sw=2 sts=2 expandtab: