6b0cdd1e06c298e4a8261cb707274e1e9bd3ab7d
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import Debug.Trace
5 import qualified Control.Arrow as Arrow
6 import GHC hiding (loadModule, sigName)
7 import CoreSyn
8 import qualified CoreUtils
9 import qualified Var
10 import qualified Type
11 import qualified TyCon
12 import qualified DataCon
13 import qualified HscMain
14 import qualified SrcLoc
15 import qualified FastString
16 import qualified Maybe
17 import qualified Module
18 import qualified Data.Foldable as Foldable
19 import qualified Control.Monad.Trans.State as State
20 import Name
21 import qualified Data.Map as Map
22 import Data.Accessor
23 import Data.Generics
24 import NameEnv ( lookupNameEnv )
25 import qualified HscTypes
26 import HscTypes ( cm_binds, cm_types )
27 import MonadUtils ( liftIO )
28 import Outputable ( showSDoc, ppr )
29 import GHC.Paths ( libdir )
30 import DynFlags ( defaultDynFlags )
31 import qualified UniqSupply
32 import List ( find )
33 import qualified List
34 import qualified Monad
35
36 -- The following modules come from the ForSyDe project. They are really
37 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
38 -- ForSyDe to get access to these modules.
39 import qualified ForSyDe.Backend.VHDL.AST as AST
40 import qualified ForSyDe.Backend.VHDL.Ppr
41 import qualified ForSyDe.Backend.VHDL.FileIO
42 import qualified ForSyDe.Backend.Ppr
43 -- This is needed for rendering the pretty printed VHDL
44 import Text.PrettyPrint.HughesPJ (render)
45
46 import TranslatorTypes
47 import HsValueMap
48 import Pretty
49 import Normalize
50 import Flatten
51 import FlattenTypes
52 import VHDLTypes
53 import qualified VHDL
54
55 makeVHDL :: String -> String -> Bool -> IO ()
56 makeVHDL filename name stateful = do
57   -- Load the module
58   core <- loadModule filename
59   -- Translate to VHDL
60   vhdl <- moduleToVHDL core [(name, stateful)]
61   -- Write VHDL to file
62   let dir = "./vhdl/" ++ name ++ "/"
63   mapM (writeVHDL dir) vhdl
64   return ()
65
66 -- | Show the core structure of the given binds in the given file.
67 listBind :: String -> String -> IO ()
68 listBind filename name = do
69   core <- loadModule filename
70   let [(b, expr)] = findBinds core [name]
71   putStr "\n"
72   putStr $ prettyShow expr
73   putStr "\n\n"
74   putStr $ showSDoc $ ppr expr
75   putStr "\n\n"
76   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
77   putStr "\n\n"
78
79 -- | Translate the binds with the given names from the given core module to
80 --   VHDL. The Bool in the tuple makes the function stateful (True) or
81 --   stateless (False).
82 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
83 moduleToVHDL core list = do
84   let (names, statefuls) = unzip list
85   let binds = findBinds core names
86   -- Generate a UniqSupply
87   -- Running 
88   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
89   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
90   -- unique supply anywhere.
91   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
92   -- Turn bind into VHDL
93   let (vhdl, sess) = State.runState (mkVHDL uniqSupply binds statefuls) (TranslatorSession core 0 Map.empty)
94   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
95   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
96   return vhdl
97   where
98     -- Turns the given bind into VHDL
99     mkVHDL :: UniqSupply.UniqSupply -> [(CoreBndr, CoreExpr)] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
100     mkVHDL uniqSupply binds statefuls = do
101       let binds'' = map (Arrow.second $ normalize uniqSupply) binds
102       let binds' = trace ("Before:\n\n" ++ showSDoc ( ppr binds ) ++ "\n\nAfter:\n\n" ++ showSDoc ( ppr binds'')) binds''
103       -- Add the builtin functions
104       --mapM addBuiltIn builtin_funcs
105       -- Create entities and architectures for them
106       --Monad.zipWithM processBind statefuls binds
107       --modA tsFlatFuncs (Map.map nameFlatFunction)
108       --flatfuncs <- getA tsFlatFuncs
109       return $ VHDL.createDesignFiles binds'
110
111 -- | Write the given design file to a file with the given name inside the
112 --   given dir
113 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
114 writeVHDL dir (name, vhdl) = do
115   -- Create the dir if needed
116   exists <- Directory.doesDirectoryExist dir
117   Monad.unless exists $ Directory.createDirectory dir
118   -- Find the filename
119   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
120   -- Write the file
121   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
122
123 -- | Loads the given file and turns it into a core module.
124 loadModule :: String -> IO HscTypes.CoreModule
125 loadModule filename =
126   defaultErrorHandler defaultDynFlags $ do
127     runGhc (Just libdir) $ do
128       dflags <- getSessionDynFlags
129       setSessionDynFlags dflags
130       --target <- guessTarget "adder.hs" Nothing
131       --liftIO (print (showSDoc (ppr (target))))
132       --liftIO $ printTarget target
133       --setTargets [target]
134       --load LoadAllTargets
135       --core <- GHC.compileToCoreSimplified "Adders.hs"
136       core <- GHC.compileToCoreModule filename
137       return core
138
139 -- | Extracts the named binds from the given module.
140 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
141 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
142
143 -- | Extract a named bind from the given list of binds
144 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
145 findBind binds lookfor =
146   -- This ignores Recs and compares the name of the bind with lookfor,
147   -- disregarding any namespaces in OccName and extra attributes in Name and
148   -- Var.
149   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
150
151 -- | Processes the given bind as a top level bind.
152 processBind ::
153   Bool                       -- ^ Should this be stateful function?
154   -> (CoreBndr, CoreExpr)    -- ^ The bind to process
155   -> TranslatorState ()
156
157 processBind stateful bind@(var, expr) = do
158   -- Create the function signature
159   let ty = CoreUtils.exprType expr
160   let hsfunc = mkHsFunction var ty stateful
161   flattenBind hsfunc bind
162
163 -- | Flattens the given bind into the given signature and adds it to the
164 --   session. Then (recursively) finds any functions it uses and does the same
165 --   with them.
166 flattenBind ::
167   HsFunction                         -- The signature to flatten into
168   -> (CoreBndr, CoreExpr)            -- The bind to flatten
169   -> TranslatorState ()
170
171 flattenBind hsfunc bind@(var, expr) = do
172   -- Flatten the function
173   let flatfunc = flattenFunction hsfunc bind
174   -- Propagate state variables
175   let flatfunc' = propagateState hsfunc flatfunc
176   -- Store the flat function in the session
177   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
178   -- Flatten any functions used
179   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
180   mapM_ resolvFunc used_hsfuncs
181
182 -- | Decide which incoming state variables will become state in the
183 --   given function, and which will be propagate to other applied
184 --   functions.
185 propagateState ::
186   HsFunction
187   -> FlatFunction
188   -> FlatFunction
189
190 propagateState hsfunc flatfunc =
191     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
192   where
193     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
194     states' = zip olds news
195     -- Find all signals used by all sigdefs
196     uses = concatMap sigDefUses (flat_defs flatfunc)
197     -- Find all signals that are used more than once (is there a
198     -- prettier way to do this?)
199     multiple_uses = uses List.\\ (List.nub uses)
200     -- Find the states whose "old state" signal is used only once
201     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
202     -- See if these single use states can be propagated
203     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
204     substate_sigs = concat substate_sigss
205     -- Mark any propagated state signals as SigSubState
206     sigs' = map 
207       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
208       (flat_sigs flatfunc)
209
210 -- | Propagate the state into a single function application.
211 propagateState' ::
212   [(SignalId, SignalId)]
213                       -- ^ TODO
214   -> SigDef           -- ^ The SigDef to process.
215   -> ([SignalId], SigDef) 
216                       -- ^ Any signal ids that should become substates,
217                       --   and the resulting application.
218
219 propagateState' states def =
220     if (is_FApp def) then
221       (our_old ++ our_new, def {appFunc = hsfunc'})
222     else
223       ([], def)
224   where
225     hsfunc = appFunc def
226     args = appArgs def
227     res = appRes def
228     our_states = filter our_state states
229     -- A state signal belongs in this function if the old state is
230     -- passed in, and the new state returned
231     our_state (old, new) =
232       any (old `Foldable.elem`) args
233       && new `Foldable.elem` res
234     (our_old, our_new) = unzip our_states
235     -- Mark the result
236     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
237     res' = fmap (mark_state (zip our_new [0..])) zipped_res
238     -- Mark the args
239     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
240     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
241     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
242
243     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
244     mark_state states (id, use) =
245       case lookup id states of
246         Nothing -> use
247         Just state_id -> State state_id
248
249 -- | Returns pairs of signals that should be mapped to state in this function.
250 getStateSignals ::
251   HsFunction                      -- | The function to look at
252   -> FlatFunction                 -- | The function to look at
253   -> [(SignalId, SignalId)]   
254         -- | TODO The state signals. The first is the state number, the second the
255         --   signal to assign the current state to, the last is the signal
256         --   that holds the new state.
257
258 getStateSignals hsfunc flatfunc =
259   [(old_id, new_id) 
260     | (old_num, old_id) <- args
261     , (new_num, new_id) <- res
262     , old_num == new_num]
263   where
264     sigs = flat_sigs flatfunc
265     -- Translate args and res to lists of (statenum, sigid)
266     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
267     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
268     
269 -- | Find the given function, flatten it and add it to the session. Then
270 --   (recursively) do the same for any functions used.
271 resolvFunc ::
272   HsFunction        -- | The function to look for
273   -> TranslatorState ()
274
275 resolvFunc hsfunc = do
276   flatfuncmap <- getA tsFlatFuncs
277   -- Don't do anything if there is already a flat function for this hsfunc or
278   -- when it is a builtin function.
279   Monad.unless (Map.member hsfunc flatfuncmap) $ do
280   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
281   -- New function, resolve it
282   core <- getA tsCoreModule
283   -- Find the named function
284   let name = (hsFuncName hsfunc)
285   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
286   case bind of
287     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
288     Just b  -> flattenBind hsfunc b
289
290 -- | Translate a top level function declaration to a HsFunction. i.e., which
291 --   interface will be provided by this function. This function essentially
292 --   defines the "calling convention" for hardware models.
293 mkHsFunction ::
294   Var.Var         -- ^ The function defined
295   -> Type         -- ^ The function type (including arguments!)
296   -> Bool         -- ^ Is this a stateful function?
297   -> HsFunction   -- ^ The resulting HsFunction
298
299 mkHsFunction f ty stateful=
300   HsFunction hsname hsargs hsres
301   where
302     hsname  = getOccString f
303     (arg_tys, res_ty) = Type.splitFunTys ty
304     (hsargs, hsres) = 
305       if stateful 
306       then
307         let
308           -- The last argument must be state
309           state_ty = last arg_tys
310           state    = useAsState (mkHsValueMap state_ty)
311           -- All but the last argument are inports
312           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
313           hsargs   = inports ++ [state]
314           hsres    = case splitTupleType res_ty of
315             -- Result type must be a two tuple (state, ports)
316             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
317               then
318                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
319               else
320                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
321             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
322         in
323           (hsargs, hsres)
324       else
325         -- Just use everything as a port
326         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
327
328 -- | Adds signal names to the given FlatFunction
329 nameFlatFunction ::
330   FlatFunction
331   -> FlatFunction
332
333 nameFlatFunction flatfunc =
334   -- Name the signals
335   let 
336     s = flat_sigs flatfunc
337     s' = map nameSignal s in
338   flatfunc { flat_sigs = s' }
339   where
340     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
341     nameSignal (id, info) =
342       let hints = nameHints info in
343       let parts = ("sig" : hints) ++ [show id] in
344       let name = concat $ List.intersperse "_" parts in
345       (id, info {sigName = Just name})
346
347 -- | Splits a tuple type into a list of element types, or Nothing if the type
348 --   is not a tuple type.
349 splitTupleType ::
350   Type              -- ^ The type to split
351   -> Maybe [Type]   -- ^ The tuples element types
352
353 splitTupleType ty =
354   case Type.splitTyConApp_maybe ty of
355     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
356       then
357         Just args
358       else
359         Nothing
360     Nothing -> Nothing
361
362 -- vim: set ts=8 sw=2 sts=2 expandtab: