Don't inline alu.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Control.Monad.State as State
14 import qualified Data.Foldable as Foldable
15 import Name
16 import qualified Data.Map as Map
17 import Data.Generics
18 import NameEnv ( lookupNameEnv )
19 import qualified HscTypes
20 import HscTypes ( cm_binds, cm_types )
21 import MonadUtils ( liftIO )
22 import Outputable ( showSDoc, ppr )
23 import GHC.Paths ( libdir )
24 import DynFlags ( defaultDynFlags )
25 import List ( find )
26 import qualified List
27 import qualified Monad
28
29 -- The following modules come from the ForSyDe project. They are really
30 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
31 -- ForSyDe to get access to these modules.
32 import qualified ForSyDe.Backend.VHDL.AST as AST
33 import qualified ForSyDe.Backend.VHDL.Ppr
34 import qualified ForSyDe.Backend.VHDL.FileIO
35 import qualified ForSyDe.Backend.Ppr
36 -- This is needed for rendering the pretty printed VHDL
37 import Text.PrettyPrint.HughesPJ (render)
38
39 import TranslatorTypes
40 import HsValueMap
41 import Pretty
42 import Flatten
43 import FlattenTypes
44 import VHDLTypes
45 import qualified VHDL
46
47 main = do
48   makeVHDL "Alu.hs" "register_bank" True
49
50 makeVHDL :: String -> String -> Bool -> IO ()
51 makeVHDL filename name stateful = do
52   -- Load the module
53   core <- loadModule filename
54   -- Translate to VHDL
55   vhdl <- moduleToVHDL core [(name, stateful)]
56   -- Write VHDL to file
57   let dir = "../vhdl/vhdl/" ++ name ++ "/"
58   mapM (writeVHDL dir) vhdl
59   return ()
60
61 -- | Show the core structure of the given binds in the given file.
62 listBind :: String -> String -> IO ()
63 listBind filename name = do
64   core <- loadModule filename
65   let binds = findBinds core [name]
66   putStr "\n"
67   putStr $ prettyShow binds
68   putStr "\n\n"
69   putStr $ showSDoc $ ppr binds
70   putStr "\n\n"
71
72 -- | Translate the binds with the given names from the given core module to
73 --   VHDL. The Bool in the tuple makes the function stateful (True) or
74 --   stateless (False).
75 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [AST.DesignFile]
76 moduleToVHDL core list = do
77   let (names, statefuls) = unzip list
78   --liftIO $ putStr $ prettyShow (cm_binds core)
79   let binds = findBinds core names
80   --putStr $ prettyShow binds
81   -- Turn bind into VHDL
82   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (VHDLSession core 0 Map.empty)
83   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr) vhdl
84   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
85   return vhdl
86
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL binds statefuls = do
90       -- Add the builtin functions
91       mapM addBuiltIn builtin_funcs
92       -- Create entities and architectures for them
93       Monad.zipWithM processBind statefuls binds
94       modFuncs nameFlatFunction
95       modFuncs VHDL.createEntity
96       modFuncs VHDL.createArchitecture
97       VHDL.getDesignFiles
98
99 -- | Write the given design file to a file inside the given dir
100 --   The first library unit in the designfile must be an entity, whose name
101 --   will be used as a filename.
102 writeVHDL :: String -> AST.DesignFile -> IO ()
103 writeVHDL dir vhdl = do
104   -- Create the dir if needed
105   exists <- Directory.doesDirectoryExist dir
106   Monad.unless exists $ Directory.createDirectory dir
107   -- Find the filename
108   let AST.DesignFile _ (u:us) = vhdl
109   let AST.LUEntity (AST.EntityDec id _) = u
110   let fname = dir ++ AST.fromVHDLId id ++ ".vhdl"
111   -- Write the file
112   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
113
114 -- | Loads the given file and turns it into a core module.
115 loadModule :: String -> IO HscTypes.CoreModule
116 loadModule filename =
117   defaultErrorHandler defaultDynFlags $ do
118     runGhc (Just libdir) $ do
119       dflags <- getSessionDynFlags
120       setSessionDynFlags dflags
121       --target <- guessTarget "adder.hs" Nothing
122       --liftIO (print (showSDoc (ppr (target))))
123       --liftIO $ printTarget target
124       --setTargets [target]
125       --load LoadAllTargets
126       --core <- GHC.compileToCoreSimplified "Adders.hs"
127       core <- GHC.compileToCoreSimplified filename
128       return core
129
130 -- | Extracts the named binds from the given module.
131 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
132 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
133
134 -- | Extract a named bind from the given list of binds
135 findBind :: [CoreBind] -> String -> Maybe CoreBind
136 findBind binds lookfor =
137   -- This ignores Recs and compares the name of the bind with lookfor,
138   -- disregarding any namespaces in OccName and extra attributes in Name and
139   -- Var.
140   find (\b -> case b of 
141     Rec l -> False
142     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
143   ) binds
144
145 -- | Processes the given bind as a top level bind.
146 processBind ::
147   Bool                       -- ^ Should this be stateful function?
148   -> CoreBind                -- ^ The bind to process
149   -> VHDLState ()
150
151 processBind _ (Rec _) = error "Recursive binders not supported"
152 processBind stateful bind@(NonRec var expr) = do
153   -- Create the function signature
154   let ty = CoreUtils.exprType expr
155   let hsfunc = mkHsFunction var ty stateful
156   flattenBind hsfunc bind
157
158 -- | Flattens the given bind into the given signature and adds it to the
159 --   session. Then (recursively) finds any functions it uses and does the same
160 --   with them.
161 flattenBind ::
162   HsFunction                         -- The signature to flatten into
163   -> CoreBind                        -- The bind to flatten
164   -> VHDLState ()
165
166 flattenBind _ (Rec _) = error "Recursive binders not supported"
167
168 flattenBind hsfunc bind@(NonRec var expr) = do
169   -- Add the function to the session
170   addFunc hsfunc
171   -- Flatten the function
172   let flatfunc = flattenFunction hsfunc bind
173   -- Propagate state variables
174   let flatfunc' = propagateState hsfunc flatfunc
175   -- Store the flat function in the session
176   setFlatFunc hsfunc flatfunc'
177   -- Flatten any functions used
178   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
179   State.mapM resolvFunc used_hsfuncs
180   return ()
181
182 -- | Decide which incoming state variables will become state in the
183 --   given function, and which will be propagate to other applied
184 --   functions.
185 propagateState ::
186   HsFunction
187   -> FlatFunction
188   -> FlatFunction
189
190 propagateState hsfunc flatfunc =
191     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
192   where
193     apps = filter is_FApp (flat_defs flatfunc)
194     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
195     states' = zip olds news
196     -- Find all signals used by all sigdefs
197     uses = concatMap sigDefUses (flat_defs flatfunc)
198     -- Find all signals that are used more than once (is there a
199     -- prettier way to do this?)
200     multiple_uses = uses List.\\ (List.nub uses)
201     -- Find the states whose "old state" signal is used only once
202     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
203     -- See if these single use states can be propagated
204     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) apps
205     substate_sigs = concat substate_sigss
206     -- Mark any propagated state signals as SigSubState
207     sigs' = map 
208       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
209       (flat_sigs flatfunc)
210
211 -- | Propagate the state into a single function application.
212 propagateState' ::
213   [(SignalId, SignalId)]
214                       -- ^ TODO
215   -> SigDef           -- ^ The function application to process. Must be
216                       --   a FApp constructor.
217   -> ([SignalId], SigDef) 
218                       -- ^ Any signal ids that should become substates,
219                       --   and the resulting application.
220
221 propagateState' states app =
222     (our_old ++ our_new, app {appFunc = hsfunc'})
223   where
224     hsfunc = appFunc app
225     args = appArgs app
226     res = appRes app
227     our_states = filter our_state states
228     -- A state signal belongs in this function if the old state is
229     -- passed in, and the new state returned
230     our_state (old, new) =
231       any (old `Foldable.elem`) args
232       && new `Foldable.elem` res
233     (our_old, our_new) = unzip our_states
234     -- Mark the result
235     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
236     res' = fmap (mark_state (zip our_new [0..])) zipped_res
237     -- Mark the args
238     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
239     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
240     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
241
242     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
243     mark_state states (id, use) =
244       case lookup id states of
245         Nothing -> use
246         Just state_id -> State state_id
247
248 -- | Returns pairs of signals that should be mapped to state in this function.
249 getStateSignals ::
250   HsFunction                      -- | The function to look at
251   -> FlatFunction                 -- | The function to look at
252   -> [(SignalId, SignalId)]   
253         -- | TODO The state signals. The first is the state number, the second the
254         --   signal to assign the current state to, the last is the signal
255         --   that holds the new state.
256
257 getStateSignals hsfunc flatfunc =
258   [(old_id, new_id) 
259     | (old_num, old_id) <- args
260     , (new_num, new_id) <- res
261     , old_num == new_num]
262   where
263     sigs = flat_sigs flatfunc
264     -- Translate args and res to lists of (statenum, sigid)
265     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
266     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
267     
268 -- | Find the given function, flatten it and add it to the session. Then
269 --   (recursively) do the same for any functions used.
270 resolvFunc ::
271   HsFunction        -- | The function to look for
272   -> VHDLState ()
273
274 resolvFunc hsfunc = do
275   -- See if the function is already known
276   func <- getFunc hsfunc
277   case func of
278     -- Already known, do nothing
279     Just _ -> do
280       return ()
281     -- New function, resolve it
282     Nothing -> do
283       -- Get the current module
284       core <- getModule
285       -- Find the named function
286       let bind = findBind (cm_binds core) name
287       case bind of
288         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
289         Just b  -> flattenBind hsfunc b
290   where
291     name = hsFuncName hsfunc
292
293 -- | Translate a top level function declaration to a HsFunction. i.e., which
294 --   interface will be provided by this function. This function essentially
295 --   defines the "calling convention" for hardware models.
296 mkHsFunction ::
297   Var.Var         -- ^ The function defined
298   -> Type         -- ^ The function type (including arguments!)
299   -> Bool         -- ^ Is this a stateful function?
300   -> HsFunction   -- ^ The resulting HsFunction
301
302 mkHsFunction f ty stateful=
303   HsFunction hsname hsargs hsres
304   where
305     hsname  = getOccString f
306     (arg_tys, res_ty) = Type.splitFunTys ty
307     (hsargs, hsres) = 
308       if stateful 
309       then
310         let
311           -- The last argument must be state
312           state_ty = last arg_tys
313           state    = useAsState (mkHsValueMap state_ty)
314           -- All but the last argument are inports
315           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
316           hsargs   = inports ++ [state]
317           hsres    = case splitTupleType res_ty of
318             -- Result type must be a two tuple (state, ports)
319             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
320               then
321                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
322               else
323                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
324             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
325         in
326           (hsargs, hsres)
327       else
328         -- Just use everything as a port
329         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
330
331 -- | Adds signal names to the given FlatFunction
332 nameFlatFunction ::
333   HsFunction
334   -> FuncData
335   -> VHDLState ()
336
337 nameFlatFunction hsfunc fdata =
338   let func = flatFunc fdata in
339   case func of
340     -- Skip (builtin) functions without a FlatFunction
341     Nothing -> do return ()
342     -- Name the signals in all other functions
343     Just flatfunc ->
344       let s = flat_sigs flatfunc in
345       let s' = map nameSignal s in
346       let flatfunc' = flatfunc { flat_sigs = s' } in
347       setFlatFunc hsfunc flatfunc'
348   where
349     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
350     nameSignal (id, info) =
351       let hints = nameHints info in
352       let parts = ("sig" : hints) ++ [show id] in
353       let name = concat $ List.intersperse "_" parts in
354       (id, info {sigName = Just name})
355
356 -- | Splits a tuple type into a list of element types, or Nothing if the type
357 --   is not a tuple type.
358 splitTupleType ::
359   Type              -- ^ The type to split
360   -> Maybe [Type]   -- ^ The tuples element types
361
362 splitTupleType ty =
363   case Type.splitTyConApp_maybe ty of
364     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
365       then
366         Just args
367       else
368         Nothing
369     Nothing -> Nothing
370
371 -- | A consise representation of a (set of) ports on a builtin function
372 type PortMap = HsValueMap (String, AST.TypeMark)
373 -- | A consise representation of a builtin function
374 data BuiltIn = BuiltIn String [PortMap] PortMap
375
376 -- | Map a port specification of a builtin function to a VHDL Signal to put in
377 --   a VHDLSignalMap
378 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
379 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
380
381 -- | Translate a concise representation of a builtin function to something
382 --   that can be put into FuncMap directly.
383 addBuiltIn :: BuiltIn -> VHDLState ()
384 addBuiltIn (BuiltIn name args res) = do
385     addFunc hsfunc
386     setEntity hsfunc entity
387   where
388     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
389     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing
390
391 builtin_funcs = 
392   [ 
393     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
394     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
395     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
396     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
397   ]
398
399 -- vim: set ts=8 sw=2 sts=2 expandtab: