aa8d0d9946c2eb11263d509544eea1ba2efe9dd5
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Control.Monad.State as State
14 import qualified Data.Foldable as Foldable
15 import Name
16 import qualified Data.Map as Map
17 import Data.Accessor
18 import Data.Generics
19 import NameEnv ( lookupNameEnv )
20 import qualified HscTypes
21 import HscTypes ( cm_binds, cm_types )
22 import MonadUtils ( liftIO )
23 import Outputable ( showSDoc, ppr )
24 import GHC.Paths ( libdir )
25 import DynFlags ( defaultDynFlags )
26 import List ( find )
27 import qualified List
28 import qualified Monad
29
30 -- The following modules come from the ForSyDe project. They are really
31 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
32 -- ForSyDe to get access to these modules.
33 import qualified ForSyDe.Backend.VHDL.AST as AST
34 import qualified ForSyDe.Backend.VHDL.Ppr
35 import qualified ForSyDe.Backend.VHDL.FileIO
36 import qualified ForSyDe.Backend.Ppr
37 -- This is needed for rendering the pretty printed VHDL
38 import Text.PrettyPrint.HughesPJ (render)
39
40 import TranslatorTypes
41 import HsValueMap
42 import Pretty
43 import Flatten
44 import FlattenTypes
45 import VHDLTypes
46 import qualified VHDL
47
48 main = do
49   makeVHDL "Alu.hs" "register_bank" True
50
51 makeVHDL :: String -> String -> Bool -> IO ()
52 makeVHDL filename name stateful = do
53   -- Load the module
54   core <- loadModule filename
55   -- Translate to VHDL
56   vhdl <- moduleToVHDL core [(name, stateful)]
57   -- Write VHDL to file
58   let dir = "../vhdl/vhdl/" ++ name ++ "/"
59   mapM (writeVHDL dir) vhdl
60   return ()
61
62 -- | Show the core structure of the given binds in the given file.
63 listBind :: String -> String -> IO ()
64 listBind filename name = do
65   core <- loadModule filename
66   let binds = findBinds core [name]
67   putStr "\n"
68   putStr $ prettyShow binds
69   putStr "\n\n"
70   putStr $ showSDoc $ ppr binds
71   putStr "\n\n"
72
73 -- | Translate the binds with the given names from the given core module to
74 --   VHDL. The Bool in the tuple makes the function stateful (True) or
75 --   stateless (False).
76 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [AST.DesignFile]
77 moduleToVHDL core list = do
78   let (names, statefuls) = unzip list
79   --liftIO $ putStr $ prettyShow (cm_binds core)
80   let binds = findBinds core names
81   --putStr $ prettyShow binds
82   -- Turn bind into VHDL
83   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (VHDLSession core 0 Map.empty)
84   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr) vhdl
85   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
86   return vhdl
87
88   where
89     -- Turns the given bind into VHDL
90     mkVHDL binds statefuls = do
91       -- Add the builtin functions
92       mapM addBuiltIn builtin_funcs
93       -- Create entities and architectures for them
94       Monad.zipWithM processBind statefuls binds
95       modFuncMap $ Map.map (fdFlatFunc ^: (fmap nameFlatFunction))
96       modFuncMap $ Map.mapWithKey (\hsfunc fdata -> fdEntity ^= (VHDL.createEntity hsfunc fdata) $ fdata)
97       funcs <- getFuncMap
98       modFuncMap $ Map.mapWithKey (\hsfunc fdata -> fdArch ^= (VHDL.createArchitecture funcs hsfunc fdata) $ fdata)
99       funcs <- getFuncs
100       return $ VHDL.getDesignFiles (map snd funcs)
101
102 -- | Write the given design file to a file inside the given dir
103 --   The first library unit in the designfile must be an entity, whose name
104 --   will be used as a filename.
105 writeVHDL :: String -> AST.DesignFile -> IO ()
106 writeVHDL dir vhdl = do
107   -- Create the dir if needed
108   exists <- Directory.doesDirectoryExist dir
109   Monad.unless exists $ Directory.createDirectory dir
110   -- Find the filename
111   let AST.DesignFile _ (u:us) = vhdl
112   let AST.LUEntity (AST.EntityDec id _) = u
113   let fname = dir ++ AST.fromVHDLId id ++ ".vhdl"
114   -- Write the file
115   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
116
117 -- | Loads the given file and turns it into a core module.
118 loadModule :: String -> IO HscTypes.CoreModule
119 loadModule filename =
120   defaultErrorHandler defaultDynFlags $ do
121     runGhc (Just libdir) $ do
122       dflags <- getSessionDynFlags
123       setSessionDynFlags dflags
124       --target <- guessTarget "adder.hs" Nothing
125       --liftIO (print (showSDoc (ppr (target))))
126       --liftIO $ printTarget target
127       --setTargets [target]
128       --load LoadAllTargets
129       --core <- GHC.compileToCoreSimplified "Adders.hs"
130       core <- GHC.compileToCoreSimplified filename
131       return core
132
133 -- | Extracts the named binds from the given module.
134 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
135 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
136
137 -- | Extract a named bind from the given list of binds
138 findBind :: [CoreBind] -> String -> Maybe CoreBind
139 findBind binds lookfor =
140   -- This ignores Recs and compares the name of the bind with lookfor,
141   -- disregarding any namespaces in OccName and extra attributes in Name and
142   -- Var.
143   find (\b -> case b of 
144     Rec l -> False
145     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
146   ) binds
147
148 -- | Processes the given bind as a top level bind.
149 processBind ::
150   Bool                       -- ^ Should this be stateful function?
151   -> CoreBind                -- ^ The bind to process
152   -> TranslatorState ()
153
154 processBind _ (Rec _) = error "Recursive binders not supported"
155 processBind stateful bind@(NonRec var expr) = do
156   -- Create the function signature
157   let ty = CoreUtils.exprType expr
158   let hsfunc = mkHsFunction var ty stateful
159   flattenBind hsfunc bind
160
161 -- | Flattens the given bind into the given signature and adds it to the
162 --   session. Then (recursively) finds any functions it uses and does the same
163 --   with them.
164 flattenBind ::
165   HsFunction                         -- The signature to flatten into
166   -> CoreBind                        -- The bind to flatten
167   -> TranslatorState ()
168
169 flattenBind _ (Rec _) = error "Recursive binders not supported"
170
171 flattenBind hsfunc bind@(NonRec var expr) = do
172   -- Add the function to the session
173   addFunc hsfunc
174   -- Flatten the function
175   let flatfunc = flattenFunction hsfunc bind
176   -- Propagate state variables
177   let flatfunc' = propagateState hsfunc flatfunc
178   -- Store the flat function in the session
179   setFlatFunc hsfunc flatfunc'
180   -- Flatten any functions used
181   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
182   State.mapM resolvFunc used_hsfuncs
183   return ()
184
185 -- | Decide which incoming state variables will become state in the
186 --   given function, and which will be propagate to other applied
187 --   functions.
188 propagateState ::
189   HsFunction
190   -> FlatFunction
191   -> FlatFunction
192
193 propagateState hsfunc flatfunc =
194     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
195   where
196     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
197     states' = zip olds news
198     -- Find all signals used by all sigdefs
199     uses = concatMap sigDefUses (flat_defs flatfunc)
200     -- Find all signals that are used more than once (is there a
201     -- prettier way to do this?)
202     multiple_uses = uses List.\\ (List.nub uses)
203     -- Find the states whose "old state" signal is used only once
204     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
205     -- See if these single use states can be propagated
206     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
207     substate_sigs = concat substate_sigss
208     -- Mark any propagated state signals as SigSubState
209     sigs' = map 
210       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
211       (flat_sigs flatfunc)
212
213 -- | Propagate the state into a single function application.
214 propagateState' ::
215   [(SignalId, SignalId)]
216                       -- ^ TODO
217   -> SigDef           -- ^ The SigDef to process.
218   -> ([SignalId], SigDef) 
219                       -- ^ Any signal ids that should become substates,
220                       --   and the resulting application.
221
222 propagateState' states def =
223     if (is_FApp def) then
224       (our_old ++ our_new, def {appFunc = hsfunc'})
225     else
226       ([], def)
227   where
228     hsfunc = appFunc def
229     args = appArgs def
230     res = appRes def
231     our_states = filter our_state states
232     -- A state signal belongs in this function if the old state is
233     -- passed in, and the new state returned
234     our_state (old, new) =
235       any (old `Foldable.elem`) args
236       && new `Foldable.elem` res
237     (our_old, our_new) = unzip our_states
238     -- Mark the result
239     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
240     res' = fmap (mark_state (zip our_new [0..])) zipped_res
241     -- Mark the args
242     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
243     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
244     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
245
246     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
247     mark_state states (id, use) =
248       case lookup id states of
249         Nothing -> use
250         Just state_id -> State state_id
251
252 -- | Returns pairs of signals that should be mapped to state in this function.
253 getStateSignals ::
254   HsFunction                      -- | The function to look at
255   -> FlatFunction                 -- | The function to look at
256   -> [(SignalId, SignalId)]   
257         -- | TODO The state signals. The first is the state number, the second the
258         --   signal to assign the current state to, the last is the signal
259         --   that holds the new state.
260
261 getStateSignals hsfunc flatfunc =
262   [(old_id, new_id) 
263     | (old_num, old_id) <- args
264     , (new_num, new_id) <- res
265     , old_num == new_num]
266   where
267     sigs = flat_sigs flatfunc
268     -- Translate args and res to lists of (statenum, sigid)
269     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
270     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
271     
272 -- | Find the given function, flatten it and add it to the session. Then
273 --   (recursively) do the same for any functions used.
274 resolvFunc ::
275   HsFunction        -- | The function to look for
276   -> TranslatorState ()
277
278 resolvFunc hsfunc = do
279   -- See if the function is already known
280   func <- getFunc hsfunc
281   case func of
282     -- Already known, do nothing
283     Just _ -> do
284       return ()
285     -- New function, resolve it
286     Nothing -> do
287       -- Get the current module
288       core <- getModule
289       -- Find the named function
290       let bind = findBind (cm_binds core) name
291       case bind of
292         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
293         Just b  -> flattenBind hsfunc b
294   where
295     name = hsFuncName hsfunc
296
297 -- | Translate a top level function declaration to a HsFunction. i.e., which
298 --   interface will be provided by this function. This function essentially
299 --   defines the "calling convention" for hardware models.
300 mkHsFunction ::
301   Var.Var         -- ^ The function defined
302   -> Type         -- ^ The function type (including arguments!)
303   -> Bool         -- ^ Is this a stateful function?
304   -> HsFunction   -- ^ The resulting HsFunction
305
306 mkHsFunction f ty stateful=
307   HsFunction hsname hsargs hsres
308   where
309     hsname  = getOccString f
310     (arg_tys, res_ty) = Type.splitFunTys ty
311     (hsargs, hsres) = 
312       if stateful 
313       then
314         let
315           -- The last argument must be state
316           state_ty = last arg_tys
317           state    = useAsState (mkHsValueMap state_ty)
318           -- All but the last argument are inports
319           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
320           hsargs   = inports ++ [state]
321           hsres    = case splitTupleType res_ty of
322             -- Result type must be a two tuple (state, ports)
323             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
324               then
325                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
326               else
327                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
328             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
329         in
330           (hsargs, hsres)
331       else
332         -- Just use everything as a port
333         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
334
335 -- | Adds signal names to the given FlatFunction
336 nameFlatFunction ::
337   FlatFunction
338   -> FlatFunction
339
340 nameFlatFunction flatfunc =
341   -- Name the signals
342   let 
343     s = flat_sigs flatfunc
344     s' = map nameSignal s in
345   flatfunc { flat_sigs = s' }
346   where
347     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
348     nameSignal (id, info) =
349       let hints = nameHints info in
350       let parts = ("sig" : hints) ++ [show id] in
351       let name = concat $ List.intersperse "_" parts in
352       (id, info {sigName = Just name})
353
354 -- | Splits a tuple type into a list of element types, or Nothing if the type
355 --   is not a tuple type.
356 splitTupleType ::
357   Type              -- ^ The type to split
358   -> Maybe [Type]   -- ^ The tuples element types
359
360 splitTupleType ty =
361   case Type.splitTyConApp_maybe ty of
362     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
363       then
364         Just args
365       else
366         Nothing
367     Nothing -> Nothing
368
369 -- | A consise representation of a (set of) ports on a builtin function
370 type PortMap = HsValueMap (String, AST.TypeMark)
371 -- | A consise representation of a builtin function
372 data BuiltIn = BuiltIn String [PortMap] PortMap
373
374 -- | Map a port specification of a builtin function to a VHDL Signal to put in
375 --   a VHDLSignalMap
376 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
377 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
378
379 -- | Translate a concise representation of a builtin function to something
380 --   that can be put into FuncMap directly.
381 addBuiltIn :: BuiltIn -> TranslatorState ()
382 addBuiltIn (BuiltIn name args res) = do
383     addFunc hsfunc
384     setEntity hsfunc entity
385   where
386     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
387     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing Nothing
388
389 builtin_funcs = 
390   [ 
391     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
392     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
393     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
394     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
395   ]
396
397 -- vim: set ts=8 sw=2 sts=2 expandtab: