841f63b892c2f330adfa88bca5c837cf5a7ada23
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Control.Monad.State as State
14 import qualified Data.Foldable as Foldable
15 import Name
16 import qualified Data.Map as Map
17 import Data.Generics
18 import NameEnv ( lookupNameEnv )
19 import qualified HscTypes
20 import HscTypes ( cm_binds, cm_types )
21 import MonadUtils ( liftIO )
22 import Outputable ( showSDoc, ppr )
23 import GHC.Paths ( libdir )
24 import DynFlags ( defaultDynFlags )
25 import List ( find )
26 import qualified List
27 import qualified Monad
28
29 -- The following modules come from the ForSyDe project. They are really
30 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
31 -- ForSyDe to get access to these modules.
32 import qualified ForSyDe.Backend.VHDL.AST as AST
33 import qualified ForSyDe.Backend.VHDL.Ppr
34 import qualified ForSyDe.Backend.VHDL.FileIO
35 import qualified ForSyDe.Backend.Ppr
36 -- This is needed for rendering the pretty printed VHDL
37 import Text.PrettyPrint.HughesPJ (render)
38
39 import TranslatorTypes
40 import HsValueMap
41 import Pretty
42 import Flatten
43 import FlattenTypes
44 import VHDLTypes
45 import qualified VHDL
46
47 main = do
48   makeVHDL "Alu.hs" "register_bank" True
49
50 makeVHDL :: String -> String -> Bool -> IO ()
51 makeVHDL filename name stateful = do
52   -- Load the module
53   core <- loadModule filename
54   -- Translate to VHDL
55   vhdl <- moduleToVHDL core [(name, stateful)]
56   -- Write VHDL to file
57   let dir = "../vhdl/vhdl/" ++ name ++ "/"
58   mapM (writeVHDL dir) vhdl
59   return ()
60
61 -- | Show the core structure of the given binds in the given file.
62 listBind :: String -> String -> IO ()
63 listBind filename name = do
64   core <- loadModule filename
65   let binds = findBinds core [name]
66   putStr "\n"
67   putStr $ prettyShow binds
68   putStr "\n\n"
69   putStr $ showSDoc $ ppr binds
70   putStr "\n\n"
71
72 -- | Translate the binds with the given names from the given core module to
73 --   VHDL. The Bool in the tuple makes the function stateful (True) or
74 --   stateless (False).
75 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [AST.DesignFile]
76 moduleToVHDL core list = do
77   let (names, statefuls) = unzip list
78   --liftIO $ putStr $ prettyShow (cm_binds core)
79   let binds = findBinds core names
80   --putStr $ prettyShow binds
81   -- Turn bind into VHDL
82   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (VHDLSession core 0 Map.empty)
83   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr) vhdl
84   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
85   return vhdl
86
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL binds statefuls = do
90       -- Add the builtin functions
91       mapM addBuiltIn builtin_funcs
92       -- Create entities and architectures for them
93       Monad.zipWithM processBind statefuls binds
94       modFuncs nameFlatFunction
95       modFuncs VHDL.createEntity
96       modFuncs VHDL.createArchitecture
97       VHDL.getDesignFiles
98
99 -- | Write the given design file to a file inside the given dir
100 --   The first library unit in the designfile must be an entity, whose name
101 --   will be used as a filename.
102 writeVHDL :: String -> AST.DesignFile -> IO ()
103 writeVHDL dir vhdl = do
104   -- Create the dir if needed
105   exists <- Directory.doesDirectoryExist dir
106   Monad.unless exists $ Directory.createDirectory dir
107   -- Find the filename
108   let AST.DesignFile _ (u:us) = vhdl
109   let AST.LUEntity (AST.EntityDec id _) = u
110   let fname = dir ++ AST.fromVHDLId id ++ ".vhdl"
111   -- Write the file
112   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
113
114 -- | Loads the given file and turns it into a core module.
115 loadModule :: String -> IO HscTypes.CoreModule
116 loadModule filename =
117   defaultErrorHandler defaultDynFlags $ do
118     runGhc (Just libdir) $ do
119       dflags <- getSessionDynFlags
120       setSessionDynFlags dflags
121       --target <- guessTarget "adder.hs" Nothing
122       --liftIO (print (showSDoc (ppr (target))))
123       --liftIO $ printTarget target
124       --setTargets [target]
125       --load LoadAllTargets
126       --core <- GHC.compileToCoreSimplified "Adders.hs"
127       core <- GHC.compileToCoreSimplified filename
128       return core
129
130 -- | Extracts the named binds from the given module.
131 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
132 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
133
134 -- | Extract a named bind from the given list of binds
135 findBind :: [CoreBind] -> String -> Maybe CoreBind
136 findBind binds lookfor =
137   -- This ignores Recs and compares the name of the bind with lookfor,
138   -- disregarding any namespaces in OccName and extra attributes in Name and
139   -- Var.
140   find (\b -> case b of 
141     Rec l -> False
142     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
143   ) binds
144
145 -- | Processes the given bind as a top level bind.
146 processBind ::
147   Bool                       -- ^ Should this be stateful function?
148   -> CoreBind                -- ^ The bind to process
149   -> VHDLState ()
150
151 processBind _ (Rec _) = error "Recursive binders not supported"
152 processBind stateful bind@(NonRec var expr) = do
153   -- Create the function signature
154   let ty = CoreUtils.exprType expr
155   let hsfunc = mkHsFunction var ty stateful
156   flattenBind hsfunc bind
157
158 -- | Flattens the given bind into the given signature and adds it to the
159 --   session. Then (recursively) finds any functions it uses and does the same
160 --   with them.
161 flattenBind ::
162   HsFunction                         -- The signature to flatten into
163   -> CoreBind                        -- The bind to flatten
164   -> VHDLState ()
165
166 flattenBind _ (Rec _) = error "Recursive binders not supported"
167
168 flattenBind hsfunc bind@(NonRec var expr) = do
169   -- Add the function to the session
170   addFunc hsfunc
171   -- Flatten the function
172   let flatfunc = flattenFunction hsfunc bind
173   -- Propagate state variables
174   let flatfunc' = propagateState hsfunc flatfunc
175   -- Store the flat function in the session
176   setFlatFunc hsfunc flatfunc'
177   -- Flatten any functions used
178   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
179   State.mapM resolvFunc used_hsfuncs
180   return ()
181
182 -- | Decide which incoming state variables will become state in the
183 --   given function, and which will be propagate to other applied
184 --   functions.
185 propagateState ::
186   HsFunction
187   -> FlatFunction
188   -> FlatFunction
189
190 propagateState hsfunc flatfunc =
191     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
192   where
193     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
194     states' = zip olds news
195     -- Find all signals used by all sigdefs
196     uses = concatMap sigDefUses (flat_defs flatfunc)
197     -- Find all signals that are used more than once (is there a
198     -- prettier way to do this?)
199     multiple_uses = uses List.\\ (List.nub uses)
200     -- Find the states whose "old state" signal is used only once
201     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
202     -- See if these single use states can be propagated
203     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
204     substate_sigs = concat substate_sigss
205     -- Mark any propagated state signals as SigSubState
206     sigs' = map 
207       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
208       (flat_sigs flatfunc)
209
210 -- | Propagate the state into a single function application.
211 propagateState' ::
212   [(SignalId, SignalId)]
213                       -- ^ TODO
214   -> SigDef           -- ^ The SigDef to process.
215   -> ([SignalId], SigDef) 
216                       -- ^ Any signal ids that should become substates,
217                       --   and the resulting application.
218
219 propagateState' states def =
220     if (is_FApp def) then
221       (our_old ++ our_new, def {appFunc = hsfunc'})
222     else
223       ([], def)
224   where
225     hsfunc = appFunc def
226     args = appArgs def
227     res = appRes def
228     our_states = filter our_state states
229     -- A state signal belongs in this function if the old state is
230     -- passed in, and the new state returned
231     our_state (old, new) =
232       any (old `Foldable.elem`) args
233       && new `Foldable.elem` res
234     (our_old, our_new) = unzip our_states
235     -- Mark the result
236     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
237     res' = fmap (mark_state (zip our_new [0..])) zipped_res
238     -- Mark the args
239     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
240     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
241     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
242
243     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
244     mark_state states (id, use) =
245       case lookup id states of
246         Nothing -> use
247         Just state_id -> State state_id
248
249 -- | Returns pairs of signals that should be mapped to state in this function.
250 getStateSignals ::
251   HsFunction                      -- | The function to look at
252   -> FlatFunction                 -- | The function to look at
253   -> [(SignalId, SignalId)]   
254         -- | TODO The state signals. The first is the state number, the second the
255         --   signal to assign the current state to, the last is the signal
256         --   that holds the new state.
257
258 getStateSignals hsfunc flatfunc =
259   [(old_id, new_id) 
260     | (old_num, old_id) <- args
261     , (new_num, new_id) <- res
262     , old_num == new_num]
263   where
264     sigs = flat_sigs flatfunc
265     -- Translate args and res to lists of (statenum, sigid)
266     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
267     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
268     
269 -- | Find the given function, flatten it and add it to the session. Then
270 --   (recursively) do the same for any functions used.
271 resolvFunc ::
272   HsFunction        -- | The function to look for
273   -> VHDLState ()
274
275 resolvFunc hsfunc = do
276   -- See if the function is already known
277   func <- getFunc hsfunc
278   case func of
279     -- Already known, do nothing
280     Just _ -> do
281       return ()
282     -- New function, resolve it
283     Nothing -> do
284       -- Get the current module
285       core <- getModule
286       -- Find the named function
287       let bind = findBind (cm_binds core) name
288       case bind of
289         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
290         Just b  -> flattenBind hsfunc b
291   where
292     name = hsFuncName hsfunc
293
294 -- | Translate a top level function declaration to a HsFunction. i.e., which
295 --   interface will be provided by this function. This function essentially
296 --   defines the "calling convention" for hardware models.
297 mkHsFunction ::
298   Var.Var         -- ^ The function defined
299   -> Type         -- ^ The function type (including arguments!)
300   -> Bool         -- ^ Is this a stateful function?
301   -> HsFunction   -- ^ The resulting HsFunction
302
303 mkHsFunction f ty stateful=
304   HsFunction hsname hsargs hsres
305   where
306     hsname  = getOccString f
307     (arg_tys, res_ty) = Type.splitFunTys ty
308     (hsargs, hsres) = 
309       if stateful 
310       then
311         let
312           -- The last argument must be state
313           state_ty = last arg_tys
314           state    = useAsState (mkHsValueMap state_ty)
315           -- All but the last argument are inports
316           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
317           hsargs   = inports ++ [state]
318           hsres    = case splitTupleType res_ty of
319             -- Result type must be a two tuple (state, ports)
320             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
321               then
322                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
323               else
324                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
325             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
326         in
327           (hsargs, hsres)
328       else
329         -- Just use everything as a port
330         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
331
332 -- | Adds signal names to the given FlatFunction
333 nameFlatFunction ::
334   HsFunction
335   -> FuncData
336   -> VHDLState ()
337
338 nameFlatFunction hsfunc fdata =
339   let func = flatFunc fdata in
340   case func of
341     -- Skip (builtin) functions without a FlatFunction
342     Nothing -> do return ()
343     -- Name the signals in all other functions
344     Just flatfunc ->
345       let s = flat_sigs flatfunc in
346       let s' = map nameSignal s in
347       let flatfunc' = flatfunc { flat_sigs = s' } in
348       setFlatFunc hsfunc flatfunc'
349   where
350     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
351     nameSignal (id, info) =
352       let hints = nameHints info in
353       let parts = ("sig" : hints) ++ [show id] in
354       let name = concat $ List.intersperse "_" parts in
355       (id, info {sigName = Just name})
356
357 -- | Splits a tuple type into a list of element types, or Nothing if the type
358 --   is not a tuple type.
359 splitTupleType ::
360   Type              -- ^ The type to split
361   -> Maybe [Type]   -- ^ The tuples element types
362
363 splitTupleType ty =
364   case Type.splitTyConApp_maybe ty of
365     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
366       then
367         Just args
368       else
369         Nothing
370     Nothing -> Nothing
371
372 -- | A consise representation of a (set of) ports on a builtin function
373 type PortMap = HsValueMap (String, AST.TypeMark)
374 -- | A consise representation of a builtin function
375 data BuiltIn = BuiltIn String [PortMap] PortMap
376
377 -- | Map a port specification of a builtin function to a VHDL Signal to put in
378 --   a VHDLSignalMap
379 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
380 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
381
382 -- | Translate a concise representation of a builtin function to something
383 --   that can be put into FuncMap directly.
384 addBuiltIn :: BuiltIn -> VHDLState ()
385 addBuiltIn (BuiltIn name args res) = do
386     addFunc hsfunc
387     setEntity hsfunc entity
388   where
389     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
390     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing
391
392 builtin_funcs = 
393   [ 
394     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
395     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
396     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
397     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
398   ]
399
400 -- vim: set ts=8 sw=2 sts=2 expandtab: