Redo the global (state) structure of the translator.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified List
4 import GHC hiding (loadModule, sigName)
5 import CoreSyn
6 import qualified CoreUtils
7 import qualified Var
8 import qualified Type
9 import qualified TyCon
10 import qualified DataCon
11 import qualified Maybe
12 import qualified Module
13 import qualified Data.Foldable as Foldable
14 import qualified Control.Monad.Trans.State as State
15 import Name
16 import qualified Data.Map as Map
17 import Data.Accessor
18 import Data.Generics
19 import NameEnv ( lookupNameEnv )
20 import qualified HscTypes
21 import HscTypes ( cm_binds, cm_types )
22 import MonadUtils ( liftIO )
23 import Outputable ( showSDoc, ppr )
24 import GHC.Paths ( libdir )
25 import DynFlags ( defaultDynFlags )
26 import List ( find )
27 import qualified List
28 import qualified Monad
29
30 -- The following modules come from the ForSyDe project. They are really
31 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
32 -- ForSyDe to get access to these modules.
33 import qualified ForSyDe.Backend.VHDL.AST as AST
34 import qualified ForSyDe.Backend.VHDL.Ppr
35 import qualified ForSyDe.Backend.VHDL.FileIO
36 import qualified ForSyDe.Backend.Ppr
37 -- This is needed for rendering the pretty printed VHDL
38 import Text.PrettyPrint.HughesPJ (render)
39
40 import TranslatorTypes
41 import HsValueMap
42 import Pretty
43 import Flatten
44 import FlattenTypes
45 import VHDLTypes
46 import qualified VHDL
47
48 main = do
49   makeVHDL "Alu.hs" "register_bank" True
50
51 makeVHDL :: String -> String -> Bool -> IO ()
52 makeVHDL filename name stateful = do
53   -- Load the module
54   core <- loadModule filename
55   -- Translate to VHDL
56   vhdl <- moduleToVHDL core [(name, stateful)]
57   -- Write VHDL to file
58   let dir = "../vhdl/vhdl/" ++ name ++ "/"
59   mapM (writeVHDL dir) vhdl
60   return ()
61
62 -- | Show the core structure of the given binds in the given file.
63 listBind :: String -> String -> IO ()
64 listBind filename name = do
65   core <- loadModule filename
66   let binds = findBinds core [name]
67   putStr "\n"
68   putStr $ prettyShow binds
69   putStr "\n\n"
70   putStr $ showSDoc $ ppr binds
71   putStr "\n\n"
72
73 -- | Translate the binds with the given names from the given core module to
74 --   VHDL. The Bool in the tuple makes the function stateful (True) or
75 --   stateless (False).
76 moduleToVHDL :: HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
77 moduleToVHDL core list = do
78   let (names, statefuls) = unzip list
79   --liftIO $ putStr $ prettyShow (cm_binds core)
80   let binds = findBinds core names
81   --putStr $ prettyShow binds
82   -- Turn bind into VHDL
83   let (vhdl, sess) = State.runState (mkVHDL binds statefuls) (TranslatorSession core 0 Map.empty)
84   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
85   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
86   return vhdl
87   where
88     -- Turns the given bind into VHDL
89     mkVHDL :: [CoreBind] -> [Bool] -> TranslatorState [(AST.VHDLId, AST.DesignFile)]
90     mkVHDL binds statefuls = do
91       -- Add the builtin functions
92       --mapM addBuiltIn builtin_funcs
93       -- Create entities and architectures for them
94       Monad.zipWithM processBind statefuls binds
95       modA tsFlatFuncs (Map.map nameFlatFunction)
96       flatfuncs <- getA tsFlatFuncs
97       return $ VHDL.createDesignFiles flatfuncs
98
99 -- | Write the given design file to a file with the given name inside the
100 --   given dir
101 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
102 writeVHDL dir (name, vhdl) = do
103   -- Create the dir if needed
104   exists <- Directory.doesDirectoryExist dir
105   Monad.unless exists $ Directory.createDirectory dir
106   -- Find the filename
107   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
108   -- Write the file
109   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
110
111 -- | Loads the given file and turns it into a core module.
112 loadModule :: String -> IO HscTypes.CoreModule
113 loadModule filename =
114   defaultErrorHandler defaultDynFlags $ do
115     runGhc (Just libdir) $ do
116       dflags <- getSessionDynFlags
117       setSessionDynFlags dflags
118       --target <- guessTarget "adder.hs" Nothing
119       --liftIO (print (showSDoc (ppr (target))))
120       --liftIO $ printTarget target
121       --setTargets [target]
122       --load LoadAllTargets
123       --core <- GHC.compileToCoreSimplified "Adders.hs"
124       core <- GHC.compileToCoreSimplified filename
125       return core
126
127 -- | Extracts the named binds from the given module.
128 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
129 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
130
131 -- | Extract a named bind from the given list of binds
132 findBind :: [CoreBind] -> String -> Maybe CoreBind
133 findBind binds lookfor =
134   -- This ignores Recs and compares the name of the bind with lookfor,
135   -- disregarding any namespaces in OccName and extra attributes in Name and
136   -- Var.
137   find (\b -> case b of 
138     Rec l -> False
139     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
140   ) binds
141
142 -- | Processes the given bind as a top level bind.
143 processBind ::
144   Bool                       -- ^ Should this be stateful function?
145   -> CoreBind                -- ^ The bind to process
146   -> TranslatorState ()
147
148 processBind _ (Rec _) = error "Recursive binders not supported"
149 processBind stateful bind@(NonRec var expr) = do
150   -- Create the function signature
151   let ty = CoreUtils.exprType expr
152   let hsfunc = mkHsFunction var ty stateful
153   flattenBind hsfunc bind
154
155 -- | Flattens the given bind into the given signature and adds it to the
156 --   session. Then (recursively) finds any functions it uses and does the same
157 --   with them.
158 flattenBind ::
159   HsFunction                         -- The signature to flatten into
160   -> CoreBind                        -- The bind to flatten
161   -> TranslatorState ()
162
163 flattenBind _ (Rec _) = error "Recursive binders not supported"
164
165 flattenBind hsfunc bind@(NonRec var expr) = do
166   -- Flatten the function
167   let flatfunc = flattenFunction hsfunc bind
168   -- Propagate state variables
169   let flatfunc' = propagateState hsfunc flatfunc
170   -- Store the flat function in the session
171   modA tsFlatFuncs (Map.insert hsfunc flatfunc)
172   -- Flatten any functions used
173   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
174   mapM_ resolvFunc used_hsfuncs
175
176 -- | Decide which incoming state variables will become state in the
177 --   given function, and which will be propagate to other applied
178 --   functions.
179 propagateState ::
180   HsFunction
181   -> FlatFunction
182   -> FlatFunction
183
184 propagateState hsfunc flatfunc =
185     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
186   where
187     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
188     states' = zip olds news
189     -- Find all signals used by all sigdefs
190     uses = concatMap sigDefUses (flat_defs flatfunc)
191     -- Find all signals that are used more than once (is there a
192     -- prettier way to do this?)
193     multiple_uses = uses List.\\ (List.nub uses)
194     -- Find the states whose "old state" signal is used only once
195     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
196     -- See if these single use states can be propagated
197     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
198     substate_sigs = concat substate_sigss
199     -- Mark any propagated state signals as SigSubState
200     sigs' = map 
201       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
202       (flat_sigs flatfunc)
203
204 -- | Propagate the state into a single function application.
205 propagateState' ::
206   [(SignalId, SignalId)]
207                       -- ^ TODO
208   -> SigDef           -- ^ The SigDef to process.
209   -> ([SignalId], SigDef) 
210                       -- ^ Any signal ids that should become substates,
211                       --   and the resulting application.
212
213 propagateState' states def =
214     if (is_FApp def) then
215       (our_old ++ our_new, def {appFunc = hsfunc'})
216     else
217       ([], def)
218   where
219     hsfunc = appFunc def
220     args = appArgs def
221     res = appRes def
222     our_states = filter our_state states
223     -- A state signal belongs in this function if the old state is
224     -- passed in, and the new state returned
225     our_state (old, new) =
226       any (old `Foldable.elem`) args
227       && new `Foldable.elem` res
228     (our_old, our_new) = unzip our_states
229     -- Mark the result
230     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
231     res' = fmap (mark_state (zip our_new [0..])) zipped_res
232     -- Mark the args
233     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
234     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
235     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
236
237     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
238     mark_state states (id, use) =
239       case lookup id states of
240         Nothing -> use
241         Just state_id -> State state_id
242
243 -- | Returns pairs of signals that should be mapped to state in this function.
244 getStateSignals ::
245   HsFunction                      -- | The function to look at
246   -> FlatFunction                 -- | The function to look at
247   -> [(SignalId, SignalId)]   
248         -- | TODO The state signals. The first is the state number, the second the
249         --   signal to assign the current state to, the last is the signal
250         --   that holds the new state.
251
252 getStateSignals hsfunc flatfunc =
253   [(old_id, new_id) 
254     | (old_num, old_id) <- args
255     , (new_num, new_id) <- res
256     , old_num == new_num]
257   where
258     sigs = flat_sigs flatfunc
259     -- Translate args and res to lists of (statenum, sigid)
260     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
261     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
262     
263 -- | Find the given function, flatten it and add it to the session. Then
264 --   (recursively) do the same for any functions used.
265 resolvFunc ::
266   HsFunction        -- | The function to look for
267   -> TranslatorState ()
268
269 resolvFunc hsfunc = do
270   flatfuncmap <- getA tsFlatFuncs
271   -- Don't do anything if there is already a flat function for this hsfunc.
272   Monad.unless (Map.member hsfunc flatfuncmap) $ do
273   -- TODO: Builtin functions
274   -- New function, resolve it
275   core <- getA tsCoreModule
276   -- Find the named function
277   let name = (hsFuncName hsfunc)
278   let bind = findBind (cm_binds core) name 
279   case bind of
280     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
281     Just b  -> flattenBind hsfunc b
282
283 -- | Translate a top level function declaration to a HsFunction. i.e., which
284 --   interface will be provided by this function. This function essentially
285 --   defines the "calling convention" for hardware models.
286 mkHsFunction ::
287   Var.Var         -- ^ The function defined
288   -> Type         -- ^ The function type (including arguments!)
289   -> Bool         -- ^ Is this a stateful function?
290   -> HsFunction   -- ^ The resulting HsFunction
291
292 mkHsFunction f ty stateful=
293   HsFunction hsname hsargs hsres
294   where
295     hsname  = getOccString f
296     (arg_tys, res_ty) = Type.splitFunTys ty
297     (hsargs, hsres) = 
298       if stateful 
299       then
300         let
301           -- The last argument must be state
302           state_ty = last arg_tys
303           state    = useAsState (mkHsValueMap state_ty)
304           -- All but the last argument are inports
305           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
306           hsargs   = inports ++ [state]
307           hsres    = case splitTupleType res_ty of
308             -- Result type must be a two tuple (state, ports)
309             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
310               then
311                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
312               else
313                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
314             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
315         in
316           (hsargs, hsres)
317       else
318         -- Just use everything as a port
319         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
320
321 -- | Adds signal names to the given FlatFunction
322 nameFlatFunction ::
323   FlatFunction
324   -> FlatFunction
325
326 nameFlatFunction flatfunc =
327   -- Name the signals
328   let 
329     s = flat_sigs flatfunc
330     s' = map nameSignal s in
331   flatfunc { flat_sigs = s' }
332   where
333     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
334     nameSignal (id, info) =
335       let hints = nameHints info in
336       let parts = ("sig" : hints) ++ [show id] in
337       let name = concat $ List.intersperse "_" parts in
338       (id, info {sigName = Just name})
339
340 -- | Splits a tuple type into a list of element types, or Nothing if the type
341 --   is not a tuple type.
342 splitTupleType ::
343   Type              -- ^ The type to split
344   -> Maybe [Type]   -- ^ The tuples element types
345
346 splitTupleType ty =
347   case Type.splitTyConApp_maybe ty of
348     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
349       then
350         Just args
351       else
352         Nothing
353     Nothing -> Nothing
354
355 -- | A consise representation of a (set of) ports on a builtin function
356 type PortMap = HsValueMap (String, AST.TypeMark)
357 -- | A consise representation of a builtin function
358 data BuiltIn = BuiltIn String [PortMap] PortMap
359
360 -- | Map a port specification of a builtin function to a VHDL Signal to put in
361 --   a VHDLSignalMap
362 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
363 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
364
365 -- | Translate a concise representation of a builtin function to something
366 --   that can be put into FuncMap directly.
367 {-
368 addBuiltIn :: BuiltIn -> TranslatorState ()
369 addBuiltIn (BuiltIn name args res) = do
370     addFunc hsfunc
371     setEntity hsfunc entity
372   where
373     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
374     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing Nothing
375
376 builtin_funcs = 
377   [ 
378     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
379     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
380     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
381     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
382   ]
383 -}
384 -- vim: set ts=8 sw=2 sts=2 expandtab: