Merge branch 'cλash' of http://git.stderr.nl/matthijs/projects/master-project
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import Debug.Trace
5 import qualified Control.Arrow as Arrow
6 import GHC hiding (loadModule, sigName)
7 import CoreSyn
8 import qualified CoreUtils
9 import qualified Var
10 import qualified Type
11 import qualified TyCon
12 import qualified DataCon
13 import qualified HscMain
14 import qualified SrcLoc
15 import qualified FastString
16 import qualified Maybe
17 import qualified Module
18 import qualified Data.Foldable as Foldable
19 import qualified Control.Monad.Trans.State as State
20 import Name
21 import qualified Data.Map as Map
22 import Data.Accessor
23 import Data.Generics
24 import NameEnv ( lookupNameEnv )
25 import qualified HscTypes
26 import HscTypes ( cm_binds, cm_types )
27 import MonadUtils ( liftIO )
28 import Outputable ( showSDoc, ppr )
29 import GHC.Paths ( libdir )
30 import DynFlags ( defaultDynFlags )
31 import qualified UniqSupply
32 import List ( find )
33 import qualified List
34 import qualified Monad
35
36 -- The following modules come from the ForSyDe project. They are really
37 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
38 -- ForSyDe to get access to these modules.
39 import qualified ForSyDe.Backend.VHDL.AST as AST
40 import qualified ForSyDe.Backend.VHDL.Ppr
41 import qualified ForSyDe.Backend.VHDL.FileIO
42 import qualified ForSyDe.Backend.Ppr
43 -- This is needed for rendering the pretty printed VHDL
44 import Text.PrettyPrint.HughesPJ (render)
45
46 import TranslatorTypes
47 import HsValueMap
48 import Pretty
49 import Normalize
50 import Flatten
51 import FlattenTypes
52 import VHDLTypes
53 import qualified VHDL
54
55 main = do
56   makeVHDL "Adders.hs" "highordtest2" True
57
58 makeVHDL :: String -> String -> Bool -> IO ()
59 makeVHDL filename name stateful = do
60   -- Load the module
61   core <- loadModule filename
62   -- Translate to VHDL
63   vhdl <- moduleToVHDL core [(name, stateful)]
64   -- Write VHDL to file
65   let dir = "./vhdl/" ++ name ++ "/"
66   mapM (writeVHDL dir) vhdl
67   return ()
68
69 -- | Show the core structure of the given binds in the given file.
70 listBind :: String -> String -> IO ()
71 listBind filename name = do
72   core <- loadModule filename
73   let [(b, expr)] = findBinds core [name]
74   putStr "\n"
75   putStr $ prettyShow expr
76   putStr "\n\n"
77   putStr $ showSDoc $ ppr expr
78   putStr "\n\n"
79   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
80   putStr "\n\n"
81
82 -- | Translate the binds with the given names from the given core module to
83 --   VHDL. The Bool in the tuple makes the function stateful (True) or
84 --   stateless (False).
85 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
86 moduleToVHDL core list = do
87   let (names, statefuls) = unzip list
88   let binds = findBinds core names
89   -- Generate a UniqSupply
90   -- Running 
91   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
92   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
93   -- unique supply anywhere.
94   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
95   -- Turn bind into VHDL
96   let (vhdl, sess) = State.runState (mkVHDL uniqSupply binds statefuls) (TranslatorSession core 0 Map.empty)
97   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
98   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
99   return vhdl
100   where
101     -- Turns the given bind into VHDL
102     mkVHDL :: UniqSupply.UniqSupply -> [(CoreBndr, CoreExpr)] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
103     mkVHDL uniqSupply binds statefuls = do
104       let binds'' = map (Arrow.second $ normalize uniqSupply) binds
105       let binds' = trace ("Before:\n\n" ++ showSDoc ( ppr binds ) ++ "\n\nAfter:\n\n" ++ showSDoc ( ppr binds'')) binds''
106       -- Add the builtin functions
107       --mapM addBuiltIn builtin_funcs
108       -- Create entities and architectures for them
109       --Monad.zipWithM processBind statefuls binds
110       --modA tsFlatFuncs (Map.map nameFlatFunction)
111       --flatfuncs <- getA tsFlatFuncs
112       return $ VHDL.createDesignFiles binds'
113
114 -- | Write the given design file to a file with the given name inside the
115 --   given dir
116 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
117 writeVHDL dir (name, vhdl) = do
118   -- Create the dir if needed
119   exists <- Directory.doesDirectoryExist dir
120   Monad.unless exists $ Directory.createDirectory dir
121   -- Find the filename
122   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
123   -- Write the file
124   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
125
126 -- | Loads the given file and turns it into a core module.
127 loadModule :: String -> IO HscTypes.CoreModule
128 loadModule filename =
129   defaultErrorHandler defaultDynFlags $ do
130     runGhc (Just libdir) $ do
131       dflags <- getSessionDynFlags
132       setSessionDynFlags dflags
133       --target <- guessTarget "adder.hs" Nothing
134       --liftIO (print (showSDoc (ppr (target))))
135       --liftIO $ printTarget target
136       --setTargets [target]
137       --load LoadAllTargets
138       --core <- GHC.compileToCoreSimplified "Adders.hs"
139       core <- GHC.compileToCoreModule filename
140       return core
141
142 -- | Extracts the named binds from the given module.
143 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
144 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
145
146 -- | Extract a named bind from the given list of binds
147 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
148 findBind binds lookfor =
149   -- This ignores Recs and compares the name of the bind with lookfor,
150   -- disregarding any namespaces in OccName and extra attributes in Name and
151   -- Var.
152   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
153
154 -- | Processes the given bind as a top level bind.
155 processBind ::
156   Bool                       -- ^ Should this be stateful function?
157   -> (CoreBndr, CoreExpr)    -- ^ The bind to process
158   -> TranslatorState ()
159
160 processBind stateful bind@(var, expr) = do
161   -- Create the function signature
162   let ty = CoreUtils.exprType expr
163   let hsfunc = mkHsFunction var ty stateful
164   flattenBind hsfunc bind
165
166 -- | Flattens the given bind into the given signature and adds it to the
167 --   session. Then (recursively) finds any functions it uses and does the same
168 --   with them.
169 flattenBind ::
170   HsFunction                         -- The signature to flatten into
171   -> (CoreBndr, CoreExpr)            -- The bind to flatten
172   -> TranslatorState ()
173
174 flattenBind hsfunc bind@(var, expr) = do
175   -- Flatten the function
176   let flatfunc = flattenFunction hsfunc bind
177   -- Propagate state variables
178   let flatfunc' = propagateState hsfunc flatfunc
179   -- Store the flat function in the session
180   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
181   -- Flatten any functions used
182   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
183   mapM_ resolvFunc used_hsfuncs
184
185 -- | Decide which incoming state variables will become state in the
186 --   given function, and which will be propagate to other applied
187 --   functions.
188 propagateState ::
189   HsFunction
190   -> FlatFunction
191   -> FlatFunction
192
193 propagateState hsfunc flatfunc =
194     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
195   where
196     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
197     states' = zip olds news
198     -- Find all signals used by all sigdefs
199     uses = concatMap sigDefUses (flat_defs flatfunc)
200     -- Find all signals that are used more than once (is there a
201     -- prettier way to do this?)
202     multiple_uses = uses List.\\ (List.nub uses)
203     -- Find the states whose "old state" signal is used only once
204     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
205     -- See if these single use states can be propagated
206     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
207     substate_sigs = concat substate_sigss
208     -- Mark any propagated state signals as SigSubState
209     sigs' = map 
210       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
211       (flat_sigs flatfunc)
212
213 -- | Propagate the state into a single function application.
214 propagateState' ::
215   [(SignalId, SignalId)]
216                       -- ^ TODO
217   -> SigDef           -- ^ The SigDef to process.
218   -> ([SignalId], SigDef) 
219                       -- ^ Any signal ids that should become substates,
220                       --   and the resulting application.
221
222 propagateState' states def =
223     if (is_FApp def) then
224       (our_old ++ our_new, def {appFunc = hsfunc'})
225     else
226       ([], def)
227   where
228     hsfunc = appFunc def
229     args = appArgs def
230     res = appRes def
231     our_states = filter our_state states
232     -- A state signal belongs in this function if the old state is
233     -- passed in, and the new state returned
234     our_state (old, new) =
235       any (old `Foldable.elem`) args
236       && new `Foldable.elem` res
237     (our_old, our_new) = unzip our_states
238     -- Mark the result
239     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
240     res' = fmap (mark_state (zip our_new [0..])) zipped_res
241     -- Mark the args
242     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
243     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
244     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
245
246     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
247     mark_state states (id, use) =
248       case lookup id states of
249         Nothing -> use
250         Just state_id -> State state_id
251
252 -- | Returns pairs of signals that should be mapped to state in this function.
253 getStateSignals ::
254   HsFunction                      -- | The function to look at
255   -> FlatFunction                 -- | The function to look at
256   -> [(SignalId, SignalId)]   
257         -- | TODO The state signals. The first is the state number, the second the
258         --   signal to assign the current state to, the last is the signal
259         --   that holds the new state.
260
261 getStateSignals hsfunc flatfunc =
262   [(old_id, new_id) 
263     | (old_num, old_id) <- args
264     , (new_num, new_id) <- res
265     , old_num == new_num]
266   where
267     sigs = flat_sigs flatfunc
268     -- Translate args and res to lists of (statenum, sigid)
269     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
270     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
271     
272 -- | Find the given function, flatten it and add it to the session. Then
273 --   (recursively) do the same for any functions used.
274 resolvFunc ::
275   HsFunction        -- | The function to look for
276   -> TranslatorState ()
277
278 resolvFunc hsfunc = do
279   flatfuncmap <- getA tsFlatFuncs
280   -- Don't do anything if there is already a flat function for this hsfunc or
281   -- when it is a builtin function.
282   Monad.unless (Map.member hsfunc flatfuncmap) $ do
283   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
284   -- New function, resolve it
285   core <- getA tsCoreModule
286   -- Find the named function
287   let name = (hsFuncName hsfunc)
288   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
289   case bind of
290     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
291     Just b  -> flattenBind hsfunc b
292
293 -- | Translate a top level function declaration to a HsFunction. i.e., which
294 --   interface will be provided by this function. This function essentially
295 --   defines the "calling convention" for hardware models.
296 mkHsFunction ::
297   Var.Var         -- ^ The function defined
298   -> Type         -- ^ The function type (including arguments!)
299   -> Bool         -- ^ Is this a stateful function?
300   -> HsFunction   -- ^ The resulting HsFunction
301
302 mkHsFunction f ty stateful=
303   HsFunction hsname hsargs hsres
304   where
305     hsname  = getOccString f
306     (arg_tys, res_ty) = Type.splitFunTys ty
307     (hsargs, hsres) = 
308       if stateful 
309       then
310         let
311           -- The last argument must be state
312           state_ty = last arg_tys
313           state    = useAsState (mkHsValueMap state_ty)
314           -- All but the last argument are inports
315           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
316           hsargs   = inports ++ [state]
317           hsres    = case splitTupleType res_ty of
318             -- Result type must be a two tuple (state, ports)
319             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
320               then
321                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
322               else
323                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
324             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
325         in
326           (hsargs, hsres)
327       else
328         -- Just use everything as a port
329         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
330
331 -- | Adds signal names to the given FlatFunction
332 nameFlatFunction ::
333   FlatFunction
334   -> FlatFunction
335
336 nameFlatFunction flatfunc =
337   -- Name the signals
338   let 
339     s = flat_sigs flatfunc
340     s' = map nameSignal s in
341   flatfunc { flat_sigs = s' }
342   where
343     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
344     nameSignal (id, info) =
345       let hints = nameHints info in
346       let parts = ("sig" : hints) ++ [show id] in
347       let name = concat $ List.intersperse "_" parts in
348       (id, info {sigName = Just name})
349
350 -- | Splits a tuple type into a list of element types, or Nothing if the type
351 --   is not a tuple type.
352 splitTupleType ::
353   Type              -- ^ The type to split
354   -> Maybe [Type]   -- ^ The tuples element types
355
356 splitTupleType ty =
357   case Type.splitTyConApp_maybe ty of
358     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
359       then
360         Just args
361       else
362         Nothing
363     Nothing -> Nothing
364
365 -- vim: set ts=8 sw=2 sts=2 expandtab: