Started adding builtin functions
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified HscMain
12 import qualified SrcLoc
13 import qualified FastString
14 import qualified Maybe
15 import qualified Module
16 import qualified Data.Foldable as Foldable
17 import qualified Control.Monad.Trans.State as State
18 import Name
19 import qualified Data.Map as Map
20 import Data.Accessor
21 import Data.Generics
22 import NameEnv ( lookupNameEnv )
23 import qualified HscTypes
24 import HscTypes ( cm_binds, cm_types )
25 import MonadUtils ( liftIO )
26 import Outputable ( showSDoc, ppr )
27 import GHC.Paths ( libdir )
28 import DynFlags ( defaultDynFlags )
29 import List ( find )
30 import qualified List
31 import qualified Monad
32
33 -- The following modules come from the ForSyDe project. They are really
34 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
35 -- ForSyDe to get access to these modules.
36 import qualified ForSyDe.Backend.VHDL.AST as AST
37 import qualified ForSyDe.Backend.VHDL.Ppr
38 import qualified ForSyDe.Backend.VHDL.FileIO
39 import qualified ForSyDe.Backend.Ppr
40 -- This is needed for rendering the pretty printed VHDL
41 import Text.PrettyPrint.HughesPJ (render)
42
43 import TranslatorTypes
44 import HsValueMap
45 import Pretty
46 import Flatten
47 import FlattenTypes
48 import VHDLTypes
49 import qualified VHDL
50
51 -- main = do
52 --   makeVHDL "Alu.hs" "exec" True
53
54 makeVHDL :: String -> String -> Bool -> IO ()
55 makeVHDL filename name stateful = do
56   -- Load the module
57   core <- loadModule filename
58   -- Translate to VHDL
59   vhdl <- moduleToVHDL core [(name, stateful)]
60   -- Write VHDL to file
61   let dir = "./vhdl/" ++ name ++ "/"
62   mapM (writeVHDL dir) vhdl
63   return ()
64
65 -- | Show the core structure of the given binds in the given file.
66 listBind :: String -> String -> IO ()
67 listBind filename name = do
68   core <- loadModule filename
69   let [(b, expr)] = findBinds core [name]
70   putStr "\n"
71   putStr $ prettyShow expr
72   putStr "\n\n"
73   putStr $ showSDoc $ ppr expr
74   putStr "\n\n"
75   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
76   putStr "\n\n"
77
78 -- | Translate the binds with the given names from the given core module to
79 --   VHDL. The Bool in the tuple makes the function stateful (True) or
80 --   stateless (False).
81 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
82 moduleToVHDL core list = do
83   let (names, statefuls) = unzip list
84   --liftIO $ putStr $ prettyShow (cm_binds core)
85   let binds = findBinds core names
86   --putStr $ prettyShow binds
87   -- Turn bind into VHDL
88   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (TranslatorSession core 0 Map.empty)
89   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
90   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
91   return vhdl
92   where
93     -- Turns the given bind into VHDL
94     mkVHDL :: [(CoreBndr, CoreExpr)] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
95     mkVHDL binds statefuls = do
96       -- Add the builtin functions
97       --mapM addBuiltIn builtin_funcs
98       -- Create entities and architectures for them
99       Monad.zipWithM processBind statefuls binds
100       modA tsFlatFuncs (Map.map nameFlatFunction)
101       flatfuncs <- getA tsFlatFuncs
102       return $ VHDL.createDesignFiles flatfuncs
103
104 -- | Write the given design file to a file with the given name inside the
105 --   given dir
106 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
107 writeVHDL dir (name, vhdl) = do
108   -- Create the dir if needed
109   exists <- Directory.doesDirectoryExist dir
110   Monad.unless exists $ Directory.createDirectory dir
111   -- Find the filename
112   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
113   -- Write the file
114   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
115
116 -- | Loads the given file and turns it into a core module.
117 loadModule :: String -> IO HscTypes.CoreModule
118 loadModule filename =
119   defaultErrorHandler defaultDynFlags $ do
120     runGhc (Just libdir) $ do
121       dflags <- getSessionDynFlags
122       setSessionDynFlags dflags
123       --target <- guessTarget "adder.hs" Nothing
124       --liftIO (print (showSDoc (ppr (target))))
125       --liftIO $ printTarget target
126       --setTargets [target]
127       --load LoadAllTargets
128       --core <- GHC.compileToCoreSimplified "Adders.hs"
129       core <- GHC.compileToCoreSimplified filename
130       return core
131
132 -- | Extracts the named binds from the given module.
133 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
134 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
135
136 -- | Extract a named bind from the given list of binds
137 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
138 findBind binds lookfor =
139   -- This ignores Recs and compares the name of the bind with lookfor,
140   -- disregarding any namespaces in OccName and extra attributes in Name and
141   -- Var.
142   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
143
144 -- | Processes the given bind as a top level bind.
145 processBind ::
146   Bool                       -- ^ Should this be stateful function?
147   -> (CoreBndr, CoreExpr)    -- ^ The bind to process
148   -> TranslatorState ()
149
150 processBind stateful bind@(var, expr) = do
151   -- Create the function signature
152   let ty = CoreUtils.exprType expr
153   let hsfunc = mkHsFunction var ty stateful
154   flattenBind hsfunc bind
155
156 -- | Flattens the given bind into the given signature and adds it to the
157 --   session. Then (recursively) finds any functions it uses and does the same
158 --   with them.
159 flattenBind ::
160   HsFunction                         -- The signature to flatten into
161   -> (CoreBndr, CoreExpr)            -- The bind to flatten
162   -> TranslatorState ()
163
164 flattenBind hsfunc bind@(var, expr) = do
165   -- Flatten the function
166   let flatfunc = flattenFunction hsfunc bind
167   -- Propagate state variables
168   let flatfunc' = propagateState hsfunc flatfunc
169   -- Store the flat function in the session
170   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
171   -- Flatten any functions used
172   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
173   mapM_ resolvFunc used_hsfuncs
174
175 -- | Decide which incoming state variables will become state in the
176 --   given function, and which will be propagate to other applied
177 --   functions.
178 propagateState ::
179   HsFunction
180   -> FlatFunction
181   -> FlatFunction
182
183 propagateState hsfunc flatfunc =
184     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
185   where
186     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
187     states' = zip olds news
188     -- Find all signals used by all sigdefs
189     uses = concatMap sigDefUses (flat_defs flatfunc)
190     -- Find all signals that are used more than once (is there a
191     -- prettier way to do this?)
192     multiple_uses = uses List.\\ (List.nub uses)
193     -- Find the states whose "old state" signal is used only once
194     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
195     -- See if these single use states can be propagated
196     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
197     substate_sigs = concat substate_sigss
198     -- Mark any propagated state signals as SigSubState
199     sigs' = map 
200       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
201       (flat_sigs flatfunc)
202
203 -- | Propagate the state into a single function application.
204 propagateState' ::
205   [(SignalId, SignalId)]
206                       -- ^ TODO
207   -> SigDef           -- ^ The SigDef to process.
208   -> ([SignalId], SigDef) 
209                       -- ^ Any signal ids that should become substates,
210                       --   and the resulting application.
211
212 propagateState' states def =
213     if (is_FApp def) then
214       (our_old ++ our_new, def {appFunc = hsfunc'})
215     else
216       ([], def)
217   where
218     hsfunc = appFunc def
219     args = appArgs def
220     res = appRes def
221     our_states = filter our_state states
222     -- A state signal belongs in this function if the old state is
223     -- passed in, and the new state returned
224     our_state (old, new) =
225       any (old `Foldable.elem`) args
226       && new `Foldable.elem` res
227     (our_old, our_new) = unzip our_states
228     -- Mark the result
229     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
230     res' = fmap (mark_state (zip our_new [0..])) zipped_res
231     -- Mark the args
232     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
233     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
234     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
235
236     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
237     mark_state states (id, use) =
238       case lookup id states of
239         Nothing -> use
240         Just state_id -> State state_id
241
242 -- | Returns pairs of signals that should be mapped to state in this function.
243 getStateSignals ::
244   HsFunction                      -- | The function to look at
245   -> FlatFunction                 -- | The function to look at
246   -> [(SignalId, SignalId)]   
247         -- | TODO The state signals. The first is the state number, the second the
248         --   signal to assign the current state to, the last is the signal
249         --   that holds the new state.
250
251 getStateSignals hsfunc flatfunc =
252   [(old_id, new_id) 
253     | (old_num, old_id) <- args
254     , (new_num, new_id) <- res
255     , old_num == new_num]
256   where
257     sigs = flat_sigs flatfunc
258     -- Translate args and res to lists of (statenum, sigid)
259     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
260     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
261     
262 -- | Find the given function, flatten it and add it to the session. Then
263 --   (recursively) do the same for any functions used.
264 resolvFunc ::
265   HsFunction        -- | The function to look for
266   -> TranslatorState ()
267
268 resolvFunc hsfunc = do
269   flatfuncmap <- getA tsFlatFuncs
270   -- Don't do anything if there is already a flat function for this hsfunc or
271   -- when it is a builtin function.
272   Monad.unless (Map.member hsfunc flatfuncmap) $ do
273   Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
274   -- New function, resolve it
275   core <- getA tsCoreModule
276   -- Find the named function
277   let name = (hsFuncName hsfunc)
278   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
279   case bind of
280     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
281     Just b  -> flattenBind hsfunc b
282
283 -- | Translate a top level function declaration to a HsFunction. i.e., which
284 --   interface will be provided by this function. This function essentially
285 --   defines the "calling convention" for hardware models.
286 mkHsFunction ::
287   Var.Var         -- ^ The function defined
288   -> Type         -- ^ The function type (including arguments!)
289   -> Bool         -- ^ Is this a stateful function?
290   -> HsFunction   -- ^ The resulting HsFunction
291
292 mkHsFunction f ty stateful=
293   HsFunction hsname hsargs hsres
294   where
295     hsname  = getOccString f
296     (arg_tys, res_ty) = Type.splitFunTys ty
297     (hsargs, hsres) = 
298       if stateful 
299       then
300         let
301           -- The last argument must be state
302           state_ty = last arg_tys
303           state    = useAsState (mkHsValueMap state_ty)
304           -- All but the last argument are inports
305           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
306           hsargs   = inports ++ [state]
307           hsres    = case splitTupleType res_ty of
308             -- Result type must be a two tuple (state, ports)
309             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
310               then
311                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
312               else
313                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
314             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
315         in
316           (hsargs, hsres)
317       else
318         -- Just use everything as a port
319         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
320
321 -- | Adds signal names to the given FlatFunction
322 nameFlatFunction ::
323   FlatFunction
324   -> FlatFunction
325
326 nameFlatFunction flatfunc =
327   -- Name the signals
328   let 
329     s = flat_sigs flatfunc
330     s' = map nameSignal s in
331   flatfunc { flat_sigs = s' }
332   where
333     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
334     nameSignal (id, info) =
335       let hints = nameHints info in
336       let parts = ("sig" : hints) ++ [show id] in
337       let name = concat $ List.intersperse "_" parts in
338       (id, info {sigName = Just name})
339
340 -- | Splits a tuple type into a list of element types, or Nothing if the type
341 --   is not a tuple type.
342 splitTupleType ::
343   Type              -- ^ The type to split
344   -> Maybe [Type]   -- ^ The tuples element types
345
346 splitTupleType ty =
347   case Type.splitTyConApp_maybe ty of
348     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
349       then
350         Just args
351       else
352         Nothing
353     Nothing -> Nothing
354
355 -- vim: set ts=8 sw=2 sts=2 expandtab: