ad36bbcb950a28f292b7dfb9fde20f87013d7712
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified System.FilePath as FilePath
4 import qualified List
5 import Debug.Trace
6 import qualified Control.Arrow as Arrow
7 import GHC hiding (loadModule, sigName)
8 import CoreSyn
9 import qualified CoreUtils
10 import qualified Var
11 import qualified Type
12 import qualified TyCon
13 import qualified DataCon
14 import qualified HscMain
15 import qualified SrcLoc
16 import qualified FastString
17 import qualified Maybe
18 import qualified Module
19 import qualified Data.Foldable as Foldable
20 import qualified Control.Monad.Trans.State as State
21 import Name
22 import qualified Data.Map as Map
23 import Data.Accessor
24 import Data.Generics
25 import NameEnv ( lookupNameEnv )
26 import qualified HscTypes
27 import HscTypes ( cm_binds, cm_types )
28 import MonadUtils ( liftIO )
29 import Outputable ( showSDoc, ppr )
30 import GHC.Paths ( libdir )
31 import DynFlags ( defaultDynFlags )
32 import qualified UniqSupply
33 import List ( find )
34 import qualified List
35 import qualified Monad
36
37 -- The following modules come from the ForSyDe project. They are really
38 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
39 -- ForSyDe to get access to these modules.
40 import qualified ForSyDe.Backend.VHDL.AST as AST
41 import qualified ForSyDe.Backend.VHDL.Ppr
42 import qualified ForSyDe.Backend.VHDL.FileIO
43 import qualified ForSyDe.Backend.Ppr
44 -- This is needed for rendering the pretty printed VHDL
45 import Text.PrettyPrint.HughesPJ (render)
46
47 import TranslatorTypes
48 import HsValueMap
49 import Pretty
50 import Normalize
51 import Flatten
52 import FlattenTypes
53 import VHDLTypes
54 import qualified VHDL
55
56 makeVHDL :: String -> String -> Bool -> IO ()
57 makeVHDL filename name stateful = do
58   -- Load the module
59   core <- loadModule filename
60   -- Translate to VHDL
61   vhdl <- moduleToVHDL core [(name, stateful)]
62   -- Write VHDL to file
63   let dir = "./vhdl/" ++ name ++ "/"
64   prepareDir dir
65   mapM (writeVHDL dir) vhdl
66   return ()
67
68 -- | Show the core structure of the given binds in the given file.
69 listBind :: String -> String -> IO ()
70 listBind filename name = do
71   core <- loadModule filename
72   let [(b, expr)] = findBinds core [name]
73   putStr "\n"
74   putStr $ prettyShow expr
75   putStr "\n\n"
76   putStr $ showSDoc $ ppr expr
77   putStr "\n\n"
78   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
79   putStr "\n\n"
80
81 -- | Translate the binds with the given names from the given core module to
82 --   VHDL. The Bool in the tuple makes the function stateful (True) or
83 --   stateless (False).
84 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
85 moduleToVHDL core list = do
86   let (names, statefuls) = unzip list
87   let binds = map fst $ findBinds core names
88   -- Generate a UniqSupply
89   -- Running 
90   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
91   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
92   -- unique supply anywhere.
93   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
94   -- Turn bind into VHDL
95   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
96   let normalized_bindings = normalizeModule uniqSupply all_bindings binds statefuls
97   let vhdl = VHDL.createDesignFiles normalized_bindings
98   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
99   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
100   return vhdl
101   where
102
103 -- | Prepares the directory for writing VHDL files. This means creating the
104 --   dir if it does not exist and removing all existing .vhdl files from it.
105 prepareDir :: String -> IO()
106 prepareDir dir = do
107   -- Create the dir if needed
108   exists <- Directory.doesDirectoryExist dir
109   Monad.unless exists $ Directory.createDirectory dir
110   -- Find all .vhdl files in the directory
111   files <- Directory.getDirectoryContents dir
112   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
113   -- Prepend the dirname to the filenames
114   let abs_to_remove = map (FilePath.combine dir) to_remove
115   -- Remove the files
116   mapM_ Directory.removeFile abs_to_remove
117
118 -- | Write the given design file to a file with the given name inside the
119 --   given dir
120 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
121 writeVHDL dir (name, vhdl) = do
122   -- Find the filename
123   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
124   -- Write the file
125   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
126
127 -- | Loads the given file and turns it into a core module.
128 loadModule :: String -> IO HscTypes.CoreModule
129 loadModule filename =
130   defaultErrorHandler defaultDynFlags $ do
131     runGhc (Just libdir) $ do
132       dflags <- getSessionDynFlags
133       setSessionDynFlags dflags
134       --target <- guessTarget "adder.hs" Nothing
135       --liftIO (print (showSDoc (ppr (target))))
136       --liftIO $ printTarget target
137       --setTargets [target]
138       --load LoadAllTargets
139       --core <- GHC.compileToCoreSimplified "Adders.hs"
140       core <- GHC.compileToCoreModule filename
141       return core
142
143 -- | Extracts the named binds from the given module.
144 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
145 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
146
147 -- | Extract a named bind from the given list of binds
148 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
149 findBind binds lookfor =
150   -- This ignores Recs and compares the name of the bind with lookfor,
151   -- disregarding any namespaces in OccName and extra attributes in Name and
152   -- Var.
153   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
154
155 -- | Flattens the given bind into the given signature and adds it to the
156 --   session. Then (recursively) finds any functions it uses and does the same
157 --   with them.
158 flattenBind ::
159   HsFunction                         -- The signature to flatten into
160   -> (CoreBndr, CoreExpr)            -- The bind to flatten
161   -> TranslatorState ()
162
163 flattenBind hsfunc bind@(var, expr) = do
164   -- Flatten the function
165   let flatfunc = flattenFunction hsfunc bind
166   -- Propagate state variables
167   let flatfunc' = propagateState hsfunc flatfunc
168   -- Store the flat function in the session
169   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
170   -- Flatten any functions used
171   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
172   mapM_ resolvFunc used_hsfuncs
173
174 -- | Decide which incoming state variables will become state in the
175 --   given function, and which will be propagate to other applied
176 --   functions.
177 propagateState ::
178   HsFunction
179   -> FlatFunction
180   -> FlatFunction
181
182 propagateState hsfunc flatfunc =
183     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
184   where
185     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
186     states' = zip olds news
187     -- Find all signals used by all sigdefs
188     uses = concatMap sigDefUses (flat_defs flatfunc)
189     -- Find all signals that are used more than once (is there a
190     -- prettier way to do this?)
191     multiple_uses = uses List.\\ (List.nub uses)
192     -- Find the states whose "old state" signal is used only once
193     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
194     -- See if these single use states can be propagated
195     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
196     substate_sigs = concat substate_sigss
197     -- Mark any propagated state signals as SigSubState
198     sigs' = map 
199       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
200       (flat_sigs flatfunc)
201
202 -- | Propagate the state into a single function application.
203 propagateState' ::
204   [(SignalId, SignalId)]
205                       -- ^ TODO
206   -> SigDef           -- ^ The SigDef to process.
207   -> ([SignalId], SigDef) 
208                       -- ^ Any signal ids that should become substates,
209                       --   and the resulting application.
210
211 propagateState' states def =
212     if (is_FApp def) then
213       (our_old ++ our_new, def {appFunc = hsfunc'})
214     else
215       ([], def)
216   where
217     hsfunc = appFunc def
218     args = appArgs def
219     res = appRes def
220     our_states = filter our_state states
221     -- A state signal belongs in this function if the old state is
222     -- passed in, and the new state returned
223     our_state (old, new) =
224       any (old `Foldable.elem`) args
225       && new `Foldable.elem` res
226     (our_old, our_new) = unzip our_states
227     -- Mark the result
228     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
229     res' = fmap (mark_state (zip our_new [0..])) zipped_res
230     -- Mark the args
231     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
232     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
233     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
234
235     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
236     mark_state states (id, use) =
237       case lookup id states of
238         Nothing -> use
239         Just state_id -> State state_id
240
241 -- | Returns pairs of signals that should be mapped to state in this function.
242 getStateSignals ::
243   HsFunction                      -- | The function to look at
244   -> FlatFunction                 -- | The function to look at
245   -> [(SignalId, SignalId)]   
246         -- | TODO The state signals. The first is the state number, the second the
247         --   signal to assign the current state to, the last is the signal
248         --   that holds the new state.
249
250 getStateSignals hsfunc flatfunc =
251   [(old_id, new_id) 
252     | (old_num, old_id) <- args
253     , (new_num, new_id) <- res
254     , old_num == new_num]
255   where
256     sigs = flat_sigs flatfunc
257     -- Translate args and res to lists of (statenum, sigid)
258     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
259     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
260     
261 -- | Find the given function, flatten it and add it to the session. Then
262 --   (recursively) do the same for any functions used.
263 resolvFunc ::
264   HsFunction        -- | The function to look for
265   -> TranslatorState ()
266
267 resolvFunc hsfunc = do
268   flatfuncmap <- getA tsFlatFuncs
269   -- Don't do anything if there is already a flat function for this hsfunc or
270   -- when it is a builtin function.
271   Monad.unless (Map.member hsfunc flatfuncmap) $ do
272   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
273   -- New function, resolve it
274   core <- getA tsCoreModule
275   -- Find the named function
276   let name = (hsFuncName hsfunc)
277   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
278   case bind of
279     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
280     Just b  -> flattenBind hsfunc b
281
282 -- | Translate a top level function declaration to a HsFunction. i.e., which
283 --   interface will be provided by this function. This function essentially
284 --   defines the "calling convention" for hardware models.
285 mkHsFunction ::
286   Var.Var         -- ^ The function defined
287   -> Type         -- ^ The function type (including arguments!)
288   -> Bool         -- ^ Is this a stateful function?
289   -> HsFunction   -- ^ The resulting HsFunction
290
291 mkHsFunction f ty stateful=
292   HsFunction hsname hsargs hsres
293   where
294     hsname  = getOccString f
295     (arg_tys, res_ty) = Type.splitFunTys ty
296     (hsargs, hsres) = 
297       if stateful 
298       then
299         let
300           -- The last argument must be state
301           state_ty = last arg_tys
302           state    = useAsState (mkHsValueMap state_ty)
303           -- All but the last argument are inports
304           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
305           hsargs   = inports ++ [state]
306           hsres    = case splitTupleType res_ty of
307             -- Result type must be a two tuple (state, ports)
308             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
309               then
310                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
311               else
312                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
313             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
314         in
315           (hsargs, hsres)
316       else
317         -- Just use everything as a port
318         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
319
320 -- | Adds signal names to the given FlatFunction
321 nameFlatFunction ::
322   FlatFunction
323   -> FlatFunction
324
325 nameFlatFunction flatfunc =
326   -- Name the signals
327   let 
328     s = flat_sigs flatfunc
329     s' = map nameSignal s in
330   flatfunc { flat_sigs = s' }
331   where
332     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
333     nameSignal (id, info) =
334       let hints = nameHints info in
335       let parts = ("sig" : hints) ++ [show id] in
336       let name = concat $ List.intersperse "_" parts in
337       (id, info {sigName = Just name})
338
339 -- | Splits a tuple type into a list of element types, or Nothing if the type
340 --   is not a tuple type.
341 splitTupleType ::
342   Type              -- ^ The type to split
343   -> Maybe [Type]   -- ^ The tuples element types
344
345 splitTupleType ty =
346   case Type.splitTyConApp_maybe ty of
347     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
348       then
349         Just args
350       else
351         Nothing
352     Nothing -> Nothing
353
354 -- vim: set ts=8 sw=2 sts=2 expandtab: