a94e3f44c39d0a7b8fcdbc99c232e76f52ef0bfa
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Data.Foldable as Foldable
14 import qualified Control.Monad.Trans.State as State
15 import Name
16 import qualified Data.Map as Map
17 import Data.Accessor
18 import Data.Generics
19 import NameEnv ( lookupNameEnv )
20 import qualified HscTypes
21 import HscTypes ( cm_binds, cm_types )
22 import MonadUtils ( liftIO )
23 import Outputable ( showSDoc, ppr )
24 import GHC.Paths ( libdir )
25 import DynFlags ( defaultDynFlags )
26 import List ( find )
27 import qualified List
28 import qualified Monad
29
30 -- The following modules come from the ForSyDe project. They are really
31 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
32 -- ForSyDe to get access to these modules.
33 import qualified ForSyDe.Backend.VHDL.AST as AST
34 import qualified ForSyDe.Backend.VHDL.Ppr
35 import qualified ForSyDe.Backend.VHDL.FileIO
36 import qualified ForSyDe.Backend.Ppr
37 -- This is needed for rendering the pretty printed VHDL
38 import Text.PrettyPrint.HughesPJ (render)
39
40 import TranslatorTypes
41 import HsValueMap
42 import Pretty
43 import Flatten
44 import FlattenTypes
45 import VHDLTypes
46 import qualified VHDL
47
48 main = do
49   makeVHDL "Alu.hs" "register_bank" True
50
51 makeVHDL :: String -> String -> Bool -> IO ()
52 makeVHDL filename name stateful = do
53   -- Load the module
54   core <- loadModule filename
55   -- Translate to VHDL
56   vhdl <- moduleToVHDL core [(name, stateful)]
57   -- Write VHDL to file
58   let dir = "../vhdl/vhdl/" ++ name ++ "/"
59   mapM (writeVHDL dir) vhdl
60   return ()
61
62 -- | Show the core structure of the given binds in the given file.
63 listBind :: String -> String -> IO ()
64 listBind filename name = do
65   core <- loadModule filename
66   let binds = findBinds core [name]
67   putStr "\n"
68   putStr $ prettyShow binds
69   putStr "\n\n"
70   putStr $ showSDoc $ ppr binds
71   putStr "\n\n"
72
73 -- | Translate the binds with the given names from the given core module to
74 --   VHDL. The Bool in the tuple makes the function stateful (True) or
75 --   stateless (False).
76 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
77 moduleToVHDL core list = do
78   let (names, statefuls) = unzip list
79   --liftIO $ putStr $ prettyShow (cm_binds core)
80   let binds = findBinds core names
81   --putStr $ prettyShow binds
82   -- Turn bind into VHDL
83   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (TranslatorSession core 0 Map.empty)
84   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
85   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
86   return vhdl
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL :: [CoreBind] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
90     mkVHDL binds statefuls = do
91       -- Add the builtin functions
92       --mapM addBuiltIn builtin_funcs
93       -- Create entities and architectures for them
94       Monad.zipWithM processBind statefuls binds
95       modA tsFlatFuncs (Map.map nameFlatFunction)
96       flatfuncs <- getA tsFlatFuncs
97       return $ VHDL.createDesignFiles flatfuncs
98
99 -- | Write the given design file to a file with the given name inside the
100 --   given dir
101 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
102 writeVHDL dir (name, vhdl) = do
103   -- Create the dir if needed
104   exists <- Directory.doesDirectoryExist dir
105   Monad.unless exists $ Directory.createDirectory dir
106   -- Find the filename
107   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
108   -- Write the file
109   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
110
111 -- | Loads the given file and turns it into a core module.
112 loadModule :: String -> IO HscTypes.CoreModule
113 loadModule filename =
114   defaultErrorHandler defaultDynFlags $ do
115     runGhc (Just libdir) $ do
116       dflags <- getSessionDynFlags
117       setSessionDynFlags dflags
118       --target <- guessTarget "adder.hs" Nothing
119       --liftIO (print (showSDoc (ppr (target))))
120       --liftIO $ printTarget target
121       --setTargets [target]
122       --load LoadAllTargets
123       --core <- GHC.compileToCoreSimplified "Adders.hs"
124       core <- GHC.compileToCoreSimplified filename
125       return core
126
127 -- | Extracts the named binds from the given module.
128 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
129 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
130
131 -- | Extract a named bind from the given list of binds
132 findBind :: [CoreBind] -> String -> Maybe CoreBind
133 findBind binds lookfor =
134   -- This ignores Recs and compares the name of the bind with lookfor,
135   -- disregarding any namespaces in OccName and extra attributes in Name and
136   -- Var.
137   find (\b -> case b of 
138     Rec l -> False
139     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
140   ) binds
141
142 -- | Processes the given bind as a top level bind.
143 processBind ::
144   Bool                       -- ^ Should this be stateful function?
145   -> CoreBind                -- ^ The bind to process
146   -> TranslatorState ()
147
148 processBind _ (Rec _) = error "Recursive binders not supported"
149 processBind stateful bind@(NonRec var expr) = do
150   -- Create the function signature
151   let ty = CoreUtils.exprType expr
152   let hsfunc = mkHsFunction var ty stateful
153   flattenBind hsfunc bind
154
155 -- | Flattens the given bind into the given signature and adds it to the
156 --   session. Then (recursively) finds any functions it uses and does the same
157 --   with them.
158 flattenBind ::
159   HsFunction                         -- The signature to flatten into
160   -> CoreBind                        -- The bind to flatten
161   -> TranslatorState ()
162
163 flattenBind _ (Rec _) = error "Recursive binders not supported"
164
165 flattenBind hsfunc bind@(NonRec var expr) = do
166   -- Flatten the function
167   let flatfunc = flattenFunction hsfunc bind
168   -- Propagate state variables
169   let flatfunc' = propagateState hsfunc flatfunc
170   -- Store the flat function in the session
171   modA tsFlatFuncs (Map.insert hsfunc flatfunc)
172   -- Flatten any functions used
173   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
174   mapM_ resolvFunc used_hsfuncs
175
176 -- | Decide which incoming state variables will become state in the
177 --   given function, and which will be propagate to other applied
178 --   functions.
179 propagateState ::
180   HsFunction
181   -> FlatFunction
182   -> FlatFunction
183
184 propagateState hsfunc flatfunc =
185     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
186   where
187     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
188     states' = zip olds news
189     -- Find all signals used by all sigdefs
190     uses = concatMap sigDefUses (flat_defs flatfunc)
191     -- Find all signals that are used more than once (is there a
192     -- prettier way to do this?)
193     multiple_uses = uses List.\\ (List.nub uses)
194     -- Find the states whose "old state" signal is used only once
195     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
196     -- See if these single use states can be propagated
197     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
198     substate_sigs = concat substate_sigss
199     -- Mark any propagated state signals as SigSubState
200     sigs' = map 
201       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
202       (flat_sigs flatfunc)
203
204 -- | Propagate the state into a single function application.
205 propagateState' ::
206   [(SignalId, SignalId)]
207                       -- ^ TODO
208   -> SigDef           -- ^ The SigDef to process.
209   -> ([SignalId], SigDef) 
210                       -- ^ Any signal ids that should become substates,
211                       --   and the resulting application.
212
213 propagateState' states def =
214     if (is_FApp def) then
215       (our_old ++ our_new, def {appFunc = hsfunc'})
216     else
217       ([], def)
218   where
219     hsfunc = appFunc def
220     args = appArgs def
221     res = appRes def
222     our_states = filter our_state states
223     -- A state signal belongs in this function if the old state is
224     -- passed in, and the new state returned
225     our_state (old, new) =
226       any (old `Foldable.elem`) args
227       && new `Foldable.elem` res
228     (our_old, our_new) = unzip our_states
229     -- Mark the result
230     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
231     res' = fmap (mark_state (zip our_new [0..])) zipped_res
232     -- Mark the args
233     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
234     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
235     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
236
237     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
238     mark_state states (id, use) =
239       case lookup id states of
240         Nothing -> use
241         Just state_id -> State state_id
242
243 -- | Returns pairs of signals that should be mapped to state in this function.
244 getStateSignals ::
245   HsFunction                      -- | The function to look at
246   -> FlatFunction                 -- | The function to look at
247   -> [(SignalId, SignalId)]   
248         -- | TODO The state signals. The first is the state number, the second the
249         --   signal to assign the current state to, the last is the signal
250         --   that holds the new state.
251
252 getStateSignals hsfunc flatfunc =
253   [(old_id, new_id) 
254     | (old_num, old_id) <- args
255     , (new_num, new_id) <- res
256     , old_num == new_num]
257   where
258     sigs = flat_sigs flatfunc
259     -- Translate args and res to lists of (statenum, sigid)
260     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
261     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
262     
263 -- | Find the given function, flatten it and add it to the session. Then
264 --   (recursively) do the same for any functions used.
265 resolvFunc ::
266   HsFunction        -- | The function to look for
267   -> TranslatorState ()
268
269 resolvFunc hsfunc = do
270   flatfuncmap <- getA tsFlatFuncs
271   -- Don't do anything if there is already a flat function for this hsfunc or
272   -- when it is a builtin function.
273   Monad.unless (Map.member hsfunc flatfuncmap) $ do
274   Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
275   -- TODO: Builtin functions
276   -- New function, resolve it
277   core <- getA tsCoreModule
278   -- Find the named function
279   let name = (hsFuncName hsfunc)
280   let bind = findBind (cm_binds core) name 
281   case bind of
282     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
283     Just b  -> flattenBind hsfunc b
284
285 -- | Translate a top level function declaration to a HsFunction. i.e., which
286 --   interface will be provided by this function. This function essentially
287 --   defines the "calling convention" for hardware models.
288 mkHsFunction ::
289   Var.Var         -- ^ The function defined
290   -> Type         -- ^ The function type (including arguments!)
291   -> Bool         -- ^ Is this a stateful function?
292   -> HsFunction   -- ^ The resulting HsFunction
293
294 mkHsFunction f ty stateful=
295   HsFunction hsname hsargs hsres
296   where
297     hsname  = getOccString f
298     (arg_tys, res_ty) = Type.splitFunTys ty
299     (hsargs, hsres) = 
300       if stateful 
301       then
302         let
303           -- The last argument must be state
304           state_ty = last arg_tys
305           state    = useAsState (mkHsValueMap state_ty)
306           -- All but the last argument are inports
307           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
308           hsargs   = inports ++ [state]
309           hsres    = case splitTupleType res_ty of
310             -- Result type must be a two tuple (state, ports)
311             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
312               then
313                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
314               else
315                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
316             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
317         in
318           (hsargs, hsres)
319       else
320         -- Just use everything as a port
321         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
322
323 -- | Adds signal names to the given FlatFunction
324 nameFlatFunction ::
325   FlatFunction
326   -> FlatFunction
327
328 nameFlatFunction flatfunc =
329   -- Name the signals
330   let 
331     s = flat_sigs flatfunc
332     s' = map nameSignal s in
333   flatfunc { flat_sigs = s' }
334   where
335     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
336     nameSignal (id, info) =
337       let hints = nameHints info in
338       let parts = ("sig" : hints) ++ [show id] in
339       let name = concat $ List.intersperse "_" parts in
340       (id, info {sigName = Just name})
341
342 -- | Splits a tuple type into a list of element types, or Nothing if the type
343 --   is not a tuple type.
344 splitTupleType ::
345   Type              -- ^ The type to split
346   -> Maybe [Type]   -- ^ The tuples element types
347
348 splitTupleType ty =
349   case Type.splitTyConApp_maybe ty of
350     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
351       then
352         Just args
353       else
354         Nothing
355     Nothing -> Nothing
356
357 -- vim: set ts=8 sw=2 sts=2 expandtab: