Make the Alu example use 4-bit SizedWord as data.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified HscMain
12 import qualified SrcLoc
13 import qualified FastString
14 import qualified Maybe
15 import qualified Module
16 import qualified Data.Foldable as Foldable
17 import qualified Control.Monad.Trans.State as State
18 import Name
19 import qualified Data.Map as Map
20 import Data.Accessor
21 import Data.Generics
22 import NameEnv ( lookupNameEnv )
23 import qualified HscTypes
24 import HscTypes ( cm_binds, cm_types )
25 import MonadUtils ( liftIO )
26 import Outputable ( showSDoc, ppr )
27 import GHC.Paths ( libdir )
28 import DynFlags ( defaultDynFlags )
29 import List ( find )
30 import qualified List
31 import qualified Monad
32
33 -- The following modules come from the ForSyDe project. They are really
34 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
35 -- ForSyDe to get access to these modules.
36 import qualified ForSyDe.Backend.VHDL.AST as AST
37 import qualified ForSyDe.Backend.VHDL.Ppr
38 import qualified ForSyDe.Backend.VHDL.FileIO
39 import qualified ForSyDe.Backend.Ppr
40 -- This is needed for rendering the pretty printed VHDL
41 import Text.PrettyPrint.HughesPJ (render)
42
43 import TranslatorTypes
44 import HsValueMap
45 import Pretty
46 import Flatten
47 import FlattenTypes
48 import VHDLTypes
49 import qualified VHDL
50
51 main = do
52   makeVHDL "Alu.hs" "exec" True
53
54 makeVHDL :: String -> String -> Bool -> IO ()
55 makeVHDL filename name stateful = do
56   -- Load the module
57   core <- loadModule filename
58   -- Translate to VHDL
59   vhdl <- moduleToVHDL core [(name, stateful)]
60   -- Write VHDL to file
61   let dir = "../vhdl/vhdl/" ++ name ++ "/"
62   mapM (writeVHDL dir) vhdl
63   return ()
64
65 -- | Show the core structure of the given binds in the given file.
66 listBind :: String -> String -> IO ()
67 listBind filename name = do
68   core <- loadModule filename
69   let [bind] = findBinds core [name]
70   putStr "\n"
71   putStr $ prettyShow bind
72   putStr "\n\n"
73   putStr $ showSDoc $ ppr bind
74   putStr "\n\n"
75   case bind of
76     NonRec b expr -> do 
77       putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
78       putStr "\n\n"
79     otherwise -> return ()
80
81 -- | Translate the binds with the given names from the given core module to
82 --   VHDL. The Bool in the tuple makes the function stateful (True) or
83 --   stateless (False).
84 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
85 moduleToVHDL core list = do
86   let (names, statefuls) = unzip list
87   --liftIO $ putStr $ prettyShow (cm_binds core)
88   let binds = findBinds core names
89   --putStr $ prettyShow binds
90   -- Turn bind into VHDL
91   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (TranslatorSession core 0 Map.empty)
92   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
93   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
94   return vhdl
95   where
96     -- Turns the given bind into VHDL
97     mkVHDL :: [CoreBind] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
98     mkVHDL binds statefuls = do
99       -- Add the builtin functions
100       --mapM addBuiltIn builtin_funcs
101       -- Create entities and architectures for them
102       Monad.zipWithM processBind statefuls binds
103       modA tsFlatFuncs (Map.map nameFlatFunction)
104       flatfuncs <- getA tsFlatFuncs
105       return $ VHDL.createDesignFiles flatfuncs
106
107 -- | Write the given design file to a file with the given name inside the
108 --   given dir
109 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
110 writeVHDL dir (name, vhdl) = do
111   -- Create the dir if needed
112   exists <- Directory.doesDirectoryExist dir
113   Monad.unless exists $ Directory.createDirectory dir
114   -- Find the filename
115   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
116   -- Write the file
117   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
118
119 -- | Loads the given file and turns it into a core module.
120 loadModule :: String -> IO HscTypes.CoreModule
121 loadModule filename =
122   defaultErrorHandler defaultDynFlags $ do
123     runGhc (Just libdir) $ do
124       dflags <- getSessionDynFlags
125       setSessionDynFlags dflags
126       --target <- guessTarget "adder.hs" Nothing
127       --liftIO (print (showSDoc (ppr (target))))
128       --liftIO $ printTarget target
129       --setTargets [target]
130       --load LoadAllTargets
131       --core <- GHC.compileToCoreSimplified "Adders.hs"
132       core <- GHC.compileToCoreSimplified filename
133       return core
134
135 -- | Extracts the named binds from the given module.
136 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
137 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
138
139 -- | Extract a named bind from the given list of binds
140 findBind :: [CoreBind] -> String -> Maybe CoreBind
141 findBind binds lookfor =
142   -- This ignores Recs and compares the name of the bind with lookfor,
143   -- disregarding any namespaces in OccName and extra attributes in Name and
144   -- Var.
145   find (\b -> case b of 
146     Rec l -> False
147     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
148   ) binds
149
150 -- | Processes the given bind as a top level bind.
151 processBind ::
152   Bool                       -- ^ Should this be stateful function?
153   -> CoreBind                -- ^ The bind to process
154   -> TranslatorState ()
155
156 processBind _ (Rec _) = error "Recursive binders not supported"
157 processBind stateful bind@(NonRec var expr) = do
158   -- Create the function signature
159   let ty = CoreUtils.exprType expr
160   let hsfunc = mkHsFunction var ty stateful
161   flattenBind hsfunc bind
162
163 -- | Flattens the given bind into the given signature and adds it to the
164 --   session. Then (recursively) finds any functions it uses and does the same
165 --   with them.
166 flattenBind ::
167   HsFunction                         -- The signature to flatten into
168   -> CoreBind                        -- The bind to flatten
169   -> TranslatorState ()
170
171 flattenBind _ (Rec _) = error "Recursive binders not supported"
172
173 flattenBind hsfunc bind@(NonRec var expr) = do
174   -- Flatten the function
175   let flatfunc = flattenFunction hsfunc bind
176   -- Propagate state variables
177   let flatfunc' = propagateState hsfunc flatfunc
178   -- Store the flat function in the session
179   modA tsFlatFuncs (Map.insert hsfunc flatfunc)
180   -- Flatten any functions used
181   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
182   mapM_ resolvFunc used_hsfuncs
183
184 -- | Decide which incoming state variables will become state in the
185 --   given function, and which will be propagate to other applied
186 --   functions.
187 propagateState ::
188   HsFunction
189   -> FlatFunction
190   -> FlatFunction
191
192 propagateState hsfunc flatfunc =
193     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
194   where
195     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
196     states' = zip olds news
197     -- Find all signals used by all sigdefs
198     uses = concatMap sigDefUses (flat_defs flatfunc)
199     -- Find all signals that are used more than once (is there a
200     -- prettier way to do this?)
201     multiple_uses = uses List.\\ (List.nub uses)
202     -- Find the states whose "old state" signal is used only once
203     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
204     -- See if these single use states can be propagated
205     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
206     substate_sigs = concat substate_sigss
207     -- Mark any propagated state signals as SigSubState
208     sigs' = map 
209       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
210       (flat_sigs flatfunc)
211
212 -- | Propagate the state into a single function application.
213 propagateState' ::
214   [(SignalId, SignalId)]
215                       -- ^ TODO
216   -> SigDef           -- ^ The SigDef to process.
217   -> ([SignalId], SigDef) 
218                       -- ^ Any signal ids that should become substates,
219                       --   and the resulting application.
220
221 propagateState' states def =
222     if (is_FApp def) then
223       (our_old ++ our_new, def {appFunc = hsfunc'})
224     else
225       ([], def)
226   where
227     hsfunc = appFunc def
228     args = appArgs def
229     res = appRes def
230     our_states = filter our_state states
231     -- A state signal belongs in this function if the old state is
232     -- passed in, and the new state returned
233     our_state (old, new) =
234       any (old `Foldable.elem`) args
235       && new `Foldable.elem` res
236     (our_old, our_new) = unzip our_states
237     -- Mark the result
238     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
239     res' = fmap (mark_state (zip our_new [0..])) zipped_res
240     -- Mark the args
241     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
242     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
243     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
244
245     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
246     mark_state states (id, use) =
247       case lookup id states of
248         Nothing -> use
249         Just state_id -> State state_id
250
251 -- | Returns pairs of signals that should be mapped to state in this function.
252 getStateSignals ::
253   HsFunction                      -- | The function to look at
254   -> FlatFunction                 -- | The function to look at
255   -> [(SignalId, SignalId)]   
256         -- | TODO The state signals. The first is the state number, the second the
257         --   signal to assign the current state to, the last is the signal
258         --   that holds the new state.
259
260 getStateSignals hsfunc flatfunc =
261   [(old_id, new_id) 
262     | (old_num, old_id) <- args
263     , (new_num, new_id) <- res
264     , old_num == new_num]
265   where
266     sigs = flat_sigs flatfunc
267     -- Translate args and res to lists of (statenum, sigid)
268     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
269     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
270     
271 -- | Find the given function, flatten it and add it to the session. Then
272 --   (recursively) do the same for any functions used.
273 resolvFunc ::
274   HsFunction        -- | The function to look for
275   -> TranslatorState ()
276
277 resolvFunc hsfunc = do
278   flatfuncmap <- getA tsFlatFuncs
279   -- Don't do anything if there is already a flat function for this hsfunc or
280   -- when it is a builtin function.
281   Monad.unless (Map.member hsfunc flatfuncmap) $ do
282   Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
283   -- New function, resolve it
284   core <- getA tsCoreModule
285   -- Find the named function
286   let name = (hsFuncName hsfunc)
287   let bind = findBind (cm_binds core) name 
288   case bind of
289     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
290     Just b  -> flattenBind hsfunc b
291
292 -- | Translate a top level function declaration to a HsFunction. i.e., which
293 --   interface will be provided by this function. This function essentially
294 --   defines the "calling convention" for hardware models.
295 mkHsFunction ::
296   Var.Var         -- ^ The function defined
297   -> Type         -- ^ The function type (including arguments!)
298   -> Bool         -- ^ Is this a stateful function?
299   -> HsFunction   -- ^ The resulting HsFunction
300
301 mkHsFunction f ty stateful=
302   HsFunction hsname hsargs hsres
303   where
304     hsname  = getOccString f
305     (arg_tys, res_ty) = Type.splitFunTys ty
306     (hsargs, hsres) = 
307       if stateful 
308       then
309         let
310           -- The last argument must be state
311           state_ty = last arg_tys
312           state    = useAsState (mkHsValueMap state_ty)
313           -- All but the last argument are inports
314           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
315           hsargs   = inports ++ [state]
316           hsres    = case splitTupleType res_ty of
317             -- Result type must be a two tuple (state, ports)
318             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
319               then
320                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
321               else
322                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
323             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
324         in
325           (hsargs, hsres)
326       else
327         -- Just use everything as a port
328         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
329
330 -- | Adds signal names to the given FlatFunction
331 nameFlatFunction ::
332   FlatFunction
333   -> FlatFunction
334
335 nameFlatFunction flatfunc =
336   -- Name the signals
337   let 
338     s = flat_sigs flatfunc
339     s' = map nameSignal s in
340   flatfunc { flat_sigs = s' }
341   where
342     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
343     nameSignal (id, info) =
344       let hints = nameHints info in
345       let parts = ("sig" : hints) ++ [show id] in
346       let name = concat $ List.intersperse "_" parts in
347       (id, info {sigName = Just name})
348
349 -- | Splits a tuple type into a list of element types, or Nothing if the type
350 --   is not a tuple type.
351 splitTupleType ::
352   Type              -- ^ The type to split
353   -> Maybe [Type]   -- ^ The tuples element types
354
355 splitTupleType ty =
356   case Type.splitTyConApp_maybe ty of
357     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
358       then
359         Just args
360       else
361         Nothing
362     Nothing -> Nothing
363
364 -- vim: set ts=8 sw=2 sts=2 expandtab: