85f790a6a979baac92c54a632f4798c86f2a8dfb
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified System.FilePath as FilePath
4 import qualified List
5 import Debug.Trace
6 import qualified Control.Arrow as Arrow
7 import GHC hiding (loadModule, sigName)
8 import CoreSyn
9 import qualified CoreUtils
10 import qualified Var
11 import qualified Type
12 import qualified TyCon
13 import qualified DataCon
14 import qualified HscMain
15 import qualified SrcLoc
16 import qualified FastString
17 import qualified Maybe
18 import qualified Module
19 import qualified Data.Foldable as Foldable
20 import qualified Control.Monad.Trans.State as State
21 import Name
22 import qualified Data.Map as Map
23 import Data.Accessor
24 import Data.Generics
25 import NameEnv ( lookupNameEnv )
26 import qualified HscTypes
27 import HscTypes ( cm_binds, cm_types )
28 import MonadUtils ( liftIO )
29 import Outputable ( showSDoc, ppr, showSDocDebug )
30 import GHC.Paths ( libdir )
31 import DynFlags ( defaultDynFlags )
32 import qualified UniqSupply
33 import List ( find )
34 import qualified List
35 import qualified Monad
36
37 -- The following modules come from the ForSyDe project. They are really
38 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
39 -- ForSyDe to get access to these modules.
40 import qualified ForSyDe.Backend.VHDL.AST as AST
41 import qualified ForSyDe.Backend.VHDL.Ppr
42 import qualified ForSyDe.Backend.VHDL.FileIO
43 import qualified ForSyDe.Backend.Ppr
44 -- This is needed for rendering the pretty printed VHDL
45 import Text.PrettyPrint.HughesPJ (render)
46
47 import TranslatorTypes
48 import HsValueMap
49 import Pretty
50 import Normalize
51 import Flatten
52 import FlattenTypes
53 import VHDLTypes
54 import qualified VHDL
55
56 makeVHDL :: String -> String -> Bool -> IO ()
57 makeVHDL filename name stateful = do
58   -- Load the module
59   core <- loadModule filename
60   -- Translate to VHDL
61   vhdl <- moduleToVHDL core [(name, stateful)]
62   -- Write VHDL to file
63   let dir = "./vhdl/" ++ name ++ "/"
64   prepareDir dir
65   mapM (writeVHDL dir) vhdl
66   return ()
67
68 listBindings :: String -> IO [()]
69 listBindings filename = do
70   core <- loadModule filename
71   let binds = CoreSyn.flattenBinds $ cm_binds core
72   mapM (listBinding) binds
73
74 listBinding :: (CoreBndr, CoreExpr) -> IO ()
75 listBinding (b, e) = do
76   putStr "\nBinder: "
77   putStr $ show b
78   putStr "\nExpression: \n"
79   putStr $ prettyShow e
80   putStr "\n\n"
81   putStr $ showSDoc $ ppr e
82   putStr "\n\n"
83   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
84   putStr "\n\n"
85   
86 -- | Show the core structure of the given binds in the given file.
87 listBind :: String -> String -> IO ()
88 listBind filename name = do
89   core <- loadModule filename
90   let [(b, expr)] = findBinds core [name]
91   putStr "\n"
92   putStr $ prettyShow expr
93   putStr "\n\n"
94   putStr $ showSDoc $ ppr expr
95   putStr "\n\n"
96   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
97   putStr "\n\n"
98
99 -- | Translate the binds with the given names from the given core module to
100 --   VHDL. The Bool in the tuple makes the function stateful (True) or
101 --   stateless (False).
102 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
103 moduleToVHDL core list = do
104   let (names, statefuls) = unzip list
105   let binds = map fst $ findBinds core names
106   -- Generate a UniqSupply
107   -- Running 
108   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
109   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
110   -- unique supply anywhere.
111   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
112   -- Turn bind into VHDL
113   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
114   let normalized_bindings = normalizeModule uniqSupply all_bindings binds statefuls
115   let vhdl = VHDL.createDesignFiles normalized_bindings
116   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
117   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
118   return vhdl
119   where
120
121 -- | Prepares the directory for writing VHDL files. This means creating the
122 --   dir if it does not exist and removing all existing .vhdl files from it.
123 prepareDir :: String -> IO()
124 prepareDir dir = do
125   -- Create the dir if needed
126   exists <- Directory.doesDirectoryExist dir
127   Monad.unless exists $ Directory.createDirectory dir
128   -- Find all .vhdl files in the directory
129   files <- Directory.getDirectoryContents dir
130   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
131   -- Prepend the dirname to the filenames
132   let abs_to_remove = map (FilePath.combine dir) to_remove
133   -- Remove the files
134   mapM_ Directory.removeFile abs_to_remove
135
136 -- | Write the given design file to a file with the given name inside the
137 --   given dir
138 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
139 writeVHDL dir (name, vhdl) = do
140   -- Find the filename
141   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
142   -- Write the file
143   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
144
145 -- | Loads the given file and turns it into a core module.
146 loadModule :: String -> IO HscTypes.CoreModule
147 loadModule filename =
148   defaultErrorHandler defaultDynFlags $ do
149     runGhc (Just libdir) $ do
150       dflags <- getSessionDynFlags
151       setSessionDynFlags dflags
152       --target <- guessTarget "adder.hs" Nothing
153       --liftIO (print (showSDoc (ppr (target))))
154       --liftIO $ printTarget target
155       --setTargets [target]
156       --load LoadAllTargets
157       --core <- GHC.compileToCoreSimplified "Adders.hs"
158       core <- GHC.compileToCoreModule filename
159       return core
160
161 -- | Extracts the named binds from the given module.
162 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
163 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
164
165 -- | Extract a named bind from the given list of binds
166 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
167 findBind binds lookfor =
168   -- This ignores Recs and compares the name of the bind with lookfor,
169   -- disregarding any namespaces in OccName and extra attributes in Name and
170   -- Var.
171   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
172
173 -- | Flattens the given bind into the given signature and adds it to the
174 --   session. Then (recursively) finds any functions it uses and does the same
175 --   with them.
176 flattenBind ::
177   HsFunction                         -- The signature to flatten into
178   -> (CoreBndr, CoreExpr)            -- The bind to flatten
179   -> TranslatorState ()
180
181 flattenBind hsfunc bind@(var, expr) = do
182   -- Flatten the function
183   let flatfunc = flattenFunction hsfunc bind
184   -- Propagate state variables
185   let flatfunc' = propagateState hsfunc flatfunc
186   -- Store the flat function in the session
187   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
188   -- Flatten any functions used
189   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
190   mapM_ resolvFunc used_hsfuncs
191
192 -- | Decide which incoming state variables will become state in the
193 --   given function, and which will be propagate to other applied
194 --   functions.
195 propagateState ::
196   HsFunction
197   -> FlatFunction
198   -> FlatFunction
199
200 propagateState hsfunc flatfunc =
201     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
202   where
203     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
204     states' = zip olds news
205     -- Find all signals used by all sigdefs
206     uses = concatMap sigDefUses (flat_defs flatfunc)
207     -- Find all signals that are used more than once (is there a
208     -- prettier way to do this?)
209     multiple_uses = uses List.\\ (List.nub uses)
210     -- Find the states whose "old state" signal is used only once
211     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
212     -- See if these single use states can be propagated
213     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
214     substate_sigs = concat substate_sigss
215     -- Mark any propagated state signals as SigSubState
216     sigs' = map 
217       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
218       (flat_sigs flatfunc)
219
220 -- | Propagate the state into a single function application.
221 propagateState' ::
222   [(SignalId, SignalId)]
223                       -- ^ TODO
224   -> SigDef           -- ^ The SigDef to process.
225   -> ([SignalId], SigDef) 
226                       -- ^ Any signal ids that should become substates,
227                       --   and the resulting application.
228
229 propagateState' states def =
230     if (is_FApp def) then
231       (our_old ++ our_new, def {appFunc = hsfunc'})
232     else
233       ([], def)
234   where
235     hsfunc = appFunc def
236     args = appArgs def
237     res = appRes def
238     our_states = filter our_state states
239     -- A state signal belongs in this function if the old state is
240     -- passed in, and the new state returned
241     our_state (old, new) =
242       any (old `Foldable.elem`) args
243       && new `Foldable.elem` res
244     (our_old, our_new) = unzip our_states
245     -- Mark the result
246     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
247     res' = fmap (mark_state (zip our_new [0..])) zipped_res
248     -- Mark the args
249     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
250     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
251     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
252
253     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
254     mark_state states (id, use) =
255       case lookup id states of
256         Nothing -> use
257         Just state_id -> State state_id
258
259 -- | Returns pairs of signals that should be mapped to state in this function.
260 getStateSignals ::
261   HsFunction                      -- | The function to look at
262   -> FlatFunction                 -- | The function to look at
263   -> [(SignalId, SignalId)]   
264         -- | TODO The state signals. The first is the state number, the second the
265         --   signal to assign the current state to, the last is the signal
266         --   that holds the new state.
267
268 getStateSignals hsfunc flatfunc =
269   [(old_id, new_id) 
270     | (old_num, old_id) <- args
271     , (new_num, new_id) <- res
272     , old_num == new_num]
273   where
274     sigs = flat_sigs flatfunc
275     -- Translate args and res to lists of (statenum, sigid)
276     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
277     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
278     
279 -- | Find the given function, flatten it and add it to the session. Then
280 --   (recursively) do the same for any functions used.
281 resolvFunc ::
282   HsFunction        -- | The function to look for
283   -> TranslatorState ()
284
285 resolvFunc hsfunc = do
286   flatfuncmap <- getA tsFlatFuncs
287   -- Don't do anything if there is already a flat function for this hsfunc or
288   -- when it is a builtin function.
289   Monad.unless (Map.member hsfunc flatfuncmap) $ do
290   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
291   -- New function, resolve it
292   core <- getA tsCoreModule
293   -- Find the named function
294   let name = (hsFuncName hsfunc)
295   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
296   case bind of
297     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
298     Just b  -> flattenBind hsfunc b
299
300 -- | Translate a top level function declaration to a HsFunction. i.e., which
301 --   interface will be provided by this function. This function essentially
302 --   defines the "calling convention" for hardware models.
303 mkHsFunction ::
304   Var.Var         -- ^ The function defined
305   -> Type         -- ^ The function type (including arguments!)
306   -> Bool         -- ^ Is this a stateful function?
307   -> HsFunction   -- ^ The resulting HsFunction
308
309 mkHsFunction f ty stateful=
310   HsFunction hsname hsargs hsres
311   where
312     hsname  = getOccString f
313     (arg_tys, res_ty) = Type.splitFunTys ty
314     (hsargs, hsres) = 
315       if stateful 
316       then
317         let
318           -- The last argument must be state
319           state_ty = last arg_tys
320           state    = useAsState (mkHsValueMap state_ty)
321           -- All but the last argument are inports
322           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
323           hsargs   = inports ++ [state]
324           hsres    = case splitTupleType res_ty of
325             -- Result type must be a two tuple (state, ports)
326             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
327               then
328                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
329               else
330                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
331             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
332         in
333           (hsargs, hsres)
334       else
335         -- Just use everything as a port
336         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
337
338 -- | Adds signal names to the given FlatFunction
339 nameFlatFunction ::
340   FlatFunction
341   -> FlatFunction
342
343 nameFlatFunction flatfunc =
344   -- Name the signals
345   let 
346     s = flat_sigs flatfunc
347     s' = map nameSignal s in
348   flatfunc { flat_sigs = s' }
349   where
350     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
351     nameSignal (id, info) =
352       let hints = nameHints info in
353       let parts = ("sig" : hints) ++ [show id] in
354       let name = concat $ List.intersperse "_" parts in
355       (id, info {sigName = Just name})
356
357 -- | Splits a tuple type into a list of element types, or Nothing if the type
358 --   is not a tuple type.
359 splitTupleType ::
360   Type              -- ^ The type to split
361   -> Maybe [Type]   -- ^ The tuples element types
362
363 splitTupleType ty =
364   case Type.splitTyConApp_maybe ty of
365     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
366       then
367         Just args
368       else
369         Nothing
370     Nothing -> Nothing
371
372 -- vim: set ts=8 sw=2 sts=2 expandtab: