We now support Annotations to indicate top-level entity and initial state
[matthijs/master-project/cλash.git] / cλash / CLasH / Translator.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2
3 module CLasH.Translator where
4
5 import qualified Directory
6 import qualified System.FilePath as FilePath
7 import qualified List
8 import Debug.Trace
9 import qualified Control.Arrow as Arrow
10 import GHC hiding (loadModule, sigName)
11 import CoreSyn
12 import qualified CoreUtils
13 import qualified Var
14 import qualified Type
15 import qualified TyCon
16 import qualified DataCon
17 import qualified HscMain
18 import qualified SrcLoc
19 import qualified FastString
20 import qualified Maybe
21 import qualified Module
22 import qualified Data.Foldable as Foldable
23 import qualified Control.Monad.Trans.State as State
24 import qualified Control.Monad as Monad
25 import Name
26 import qualified Data.Map as Map
27 import Data.Accessor
28 import Data.Generics
29 import NameEnv ( lookupNameEnv )
30 import qualified HscTypes
31 import HscTypes ( cm_binds, cm_types )
32 import MonadUtils ( liftIO )
33 import Outputable ( showSDoc, ppr, showSDocDebug )
34 import GHC.Paths ( libdir )
35 import DynFlags ( defaultDynFlags )
36 import qualified UniqSupply
37 import List ( find )
38 import qualified List
39 import qualified Monad
40 import qualified Annotations
41 import qualified Serialized
42
43 -- The following modules come from the ForSyDe project. They are really
44 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
45 -- ForSyDe to get access to these modules.
46 import qualified Language.VHDL.AST as AST
47 import qualified Language.VHDL.FileIO
48 import qualified Language.VHDL.Ppr as Ppr
49 -- This is needed for rendering the pretty printed VHDL
50 import Text.PrettyPrint.HughesPJ (render)
51
52 import CLasH.Translator.TranslatorTypes
53 import CLasH.Translator.Annotations
54 import CLasH.Utils.Pretty
55 import CLasH.Normalize
56 import CLasH.VHDL.VHDLTypes
57 import qualified CLasH.VHDL as VHDL
58
59 makeVHDL :: String -> String -> Bool -> IO ()
60 makeVHDL filename name stateful = do
61   -- Load the module
62   (core, env) <- loadModule filename
63   -- Translate to VHDL
64   vhdl <- moduleToVHDL env core [(name, stateful)]
65   -- Write VHDL to file
66   let dir = "./vhdl/" ++ name ++ "/"
67   prepareDir dir
68   mapM (writeVHDL dir) vhdl
69   return ()
70   
71 makeVHDLAnn :: String -> Bool -> IO ()
72 makeVHDLAnn filename stateful = do
73   (core, top, init, env) <- loadModuleAnn filename
74   let top_entity = head top
75   vhdl <- case init of 
76     [] -> moduleToVHDLAnn env core [top_entity]
77     xs -> moduleToVHDLAnnState env core [(top_entity, (head xs))]
78   let dir = "./vhdl/" ++ (show top_entity) ++ "/"
79   prepareDir dir
80   mapM (writeVHDL dir) vhdl
81   return ()
82
83 listBindings :: String -> IO [()]
84 listBindings filename = do
85   (core, env) <- loadModule filename
86   let binds = CoreSyn.flattenBinds $ cm_binds core
87   mapM (listBinding) binds
88
89 listBinding :: (CoreBndr, CoreExpr) -> IO ()
90 listBinding (b, e) = do
91   putStr "\nBinder: "
92   putStr $ show b
93   putStr "\nExpression: \n"
94   putStr $ prettyShow e
95   putStr "\n\n"
96   putStr $ showSDoc $ ppr e
97   putStr "\n\n"
98   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
99   putStr "\n\n"
100   
101 -- | Show the core structure of the given binds in the given file.
102 listBind :: String -> String -> IO ()
103 listBind filename name = do
104   (core, env) <- loadModule filename
105   let [(b, expr)] = findBinds core [name]
106   putStr "\n"
107   putStr $ prettyShow expr
108   putStr "\n\n"
109   putStr $ showSDoc $ ppr expr
110   putStr "\n\n"
111   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
112   putStr "\n\n"
113
114 -- | Translate the binds with the given names from the given core module to
115 --   VHDL. The Bool in the tuple makes the function stateful (True) or
116 --   stateless (False).
117 moduleToVHDL :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
118 moduleToVHDL env core list = do
119   let (names, statefuls) = unzip list
120   let binds = map fst $ findBinds core names
121   -- Generate a UniqSupply
122   -- Running 
123   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
124   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
125   -- unique supply anywhere.
126   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
127   -- Turn bind into VHDL
128   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
129   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds statefuls
130   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
131   mapM (putStr . render . Ppr.ppr . snd) vhdl
132   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
133   return vhdl
134   
135 moduleToVHDLAnn :: HscTypes.HscEnv -> HscTypes.CoreModule -> [CoreSyn.CoreBndr] -> IO [(AST.VHDLId, AST.DesignFile)]
136 moduleToVHDLAnn env core binds = do
137   -- Generate a UniqSupply
138   -- Running 
139   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
140   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
141   -- unique supply anywhere.
142   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
143   -- Turn bind into VHDL
144   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
145   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [False]
146   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
147   mapM (putStr . render . Ppr.ppr . snd) vhdl
148   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
149   return vhdl
150   
151 moduleToVHDLAnnState :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(CoreSyn.CoreBndr, CoreSyn.CoreBndr)] -> IO [(AST.VHDLId, AST.DesignFile)]
152 moduleToVHDLAnnState env core list = do
153   let (binds, init_states) = unzip list
154   -- Generate a UniqSupply
155   -- Running 
156   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
157   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
158   -- unique supply anywhere.
159   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
160   -- Turn bind into VHDL
161   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
162   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [True]
163   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
164   mapM (putStr . render . Ppr.ppr . snd) vhdl
165   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
166   return vhdl
167
168 -- | Prepares the directory for writing VHDL files. This means creating the
169 --   dir if it does not exist and removing all existing .vhdl files from it.
170 prepareDir :: String -> IO()
171 prepareDir dir = do
172   -- Create the dir if needed
173   exists <- Directory.doesDirectoryExist dir
174   Monad.unless exists $ Directory.createDirectory dir
175   -- Find all .vhdl files in the directory
176   files <- Directory.getDirectoryContents dir
177   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
178   -- Prepend the dirname to the filenames
179   let abs_to_remove = map (FilePath.combine dir) to_remove
180   -- Remove the files
181   mapM_ Directory.removeFile abs_to_remove
182
183 -- | Write the given design file to a file with the given name inside the
184 --   given dir
185 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
186 writeVHDL dir (name, vhdl) = do
187   -- Find the filename
188   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
189   -- Write the file
190   Language.VHDL.FileIO.writeDesignFile vhdl fname
191
192 -- | Loads the given file and turns it into a core module.
193 loadModule :: String -> IO (HscTypes.CoreModule, HscTypes.HscEnv)
194 loadModule filename =
195   defaultErrorHandler defaultDynFlags $ do
196     runGhc (Just libdir) $ do
197       dflags <- getSessionDynFlags
198       setSessionDynFlags dflags
199       --target <- guessTarget "adder.hs" Nothing
200       --liftIO (print (showSDoc (ppr (target))))
201       --liftIO $ printTarget target
202       --setTargets [target]
203       --load LoadAllTargets
204       --core <- GHC.compileToCoreSimplified "Adders.hs"
205       core <- GHC.compileToCoreModule filename
206       env <- GHC.getSession
207       return (core, env)
208       
209 -- | Loads the given file and turns it into a core module.
210 loadModuleAnn :: String -> IO (HscTypes.CoreModule, [CoreSyn.CoreBndr], [CoreSyn.CoreBndr], HscTypes.HscEnv)
211 loadModuleAnn filename =
212   defaultErrorHandler defaultDynFlags $ do
213     runGhc (Just libdir) $ do
214       dflags <- getSessionDynFlags
215       setSessionDynFlags dflags
216       --target <- guessTarget "adder.hs" Nothing
217       --liftIO (print (showSDoc (ppr (target))))
218       --liftIO $ printTarget target
219       --setTargets [target]
220       --load LoadAllTargets
221       --core <- GHC.compileToCoreSimplified "Adders.hs"
222       core <- GHC.compileToCoreModule filename
223       env <- GHC.getSession
224       top_entity <- findTopEntity core
225       init_state <- findInitState core
226       return (core, top_entity, init_state, env)
227
228 findTopEntity :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
229 findTopEntity core = do
230   let binds = CoreSyn.flattenBinds $ cm_binds core
231   topbinds <- Monad.filterM (hasTopEntityAnnotation . fst) binds
232   let bndrs = case topbinds of [] -> error $ "Couldn't find top entity in current module." ; xs -> fst (unzip topbinds)
233   return bndrs
234   
235 findInitState :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
236 findInitState core = do
237   let binds = CoreSyn.flattenBinds $ cm_binds core
238   statebinds <- Monad.filterM (hasInitStateAnnotation . fst) binds
239   let bndrs = case statebinds of [] -> [] ; xs -> fst (unzip statebinds)
240   return bndrs
241   
242 hasTopEntityAnnotation :: GhcMonad m => Var.Var -> m Bool
243 hasTopEntityAnnotation var = do
244   let deserializer = Serialized.deserializeWithData
245   let target = Annotations.NamedTarget (Var.varName var)
246   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
247   let top_ents = filter isTopEntity anns
248   case top_ents of
249     [] -> return False
250     xs -> return True
251     
252 hasInitStateAnnotation :: GhcMonad m => Var.Var -> m Bool
253 hasInitStateAnnotation var = do
254   let deserializer = Serialized.deserializeWithData
255   let target = Annotations.NamedTarget (Var.varName var)
256   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
257   let top_ents = filter isInitState anns
258   case top_ents of
259     [] -> return False
260     xs -> return True
261
262 -- | Extracts the named binds from the given module.
263 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
264 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
265
266 -- | Extract a named bind from the given list of binds
267 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
268 findBind binds lookfor =
269   -- This ignores Recs and compares the name of the bind with lookfor,
270   -- disregarding any namespaces in OccName and extra attributes in Name and
271   -- Var.
272   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
273
274 -- | Flattens the given bind into the given signature and adds it to the
275 --   session. Then (recursively) finds any functions it uses and does the same
276 --   with them.
277 -- flattenBind ::
278 --   HsFunction                         -- The signature to flatten into
279 --   -> (CoreBndr, CoreExpr)            -- The bind to flatten
280 --   -> TranslatorState ()
281 -- 
282 -- flattenBind hsfunc bind@(var, expr) = do
283 --   -- Flatten the function
284 --   let flatfunc = flattenFunction hsfunc bind
285 --   -- Propagate state variables
286 --   let flatfunc' = propagateState hsfunc flatfunc
287 --   -- Store the flat function in the session
288 --   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
289 --   -- Flatten any functions used
290 --   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
291 --   mapM_ resolvFunc used_hsfuncs
292
293 -- | Decide which incoming state variables will become state in the
294 --   given function, and which will be propagate to other applied
295 --   functions.
296 -- propagateState ::
297 --   HsFunction
298 --   -> FlatFunction
299 --   -> FlatFunction
300 -- 
301 -- propagateState hsfunc flatfunc =
302 --     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
303 --   where
304 --     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
305 --     states' = zip olds news
306 --     -- Find all signals used by all sigdefs
307 --     uses = concatMap sigDefUses (flat_defs flatfunc)
308 --     -- Find all signals that are used more than once (is there a
309 --     -- prettier way to do this?)
310 --     multiple_uses = uses List.\\ (List.nub uses)
311 --     -- Find the states whose "old state" signal is used only once
312 --     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
313 --     -- See if these single use states can be propagated
314 --     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
315 --     substate_sigs = concat substate_sigss
316 --     -- Mark any propagated state signals as SigSubState
317 --     sigs' = map 
318 --       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
319 --       (flat_sigs flatfunc)
320
321 -- | Propagate the state into a single function application.
322 -- propagateState' ::
323 --   [(SignalId, SignalId)]
324 --                       -- ^ TODO
325 --   -> SigDef           -- ^ The SigDef to process.
326 --   -> ([SignalId], SigDef) 
327 --                       -- ^ Any signal ids that should become substates,
328 --                       --   and the resulting application.
329 -- 
330 -- propagateState' states def =
331 --     if (is_FApp def) then
332 --       (our_old ++ our_new, def {appFunc = hsfunc'})
333 --     else
334 --       ([], def)
335 --   where
336 --     hsfunc = appFunc def
337 --     args = appArgs def
338 --     res = appRes def
339 --     our_states = filter our_state states
340 --     -- A state signal belongs in this function if the old state is
341 --     -- passed in, and the new state returned
342 --     our_state (old, new) =
343 --       any (old `Foldable.elem`) args
344 --       && new `Foldable.elem` res
345 --     (our_old, our_new) = unzip our_states
346 --     -- Mark the result
347 --     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
348 --     res' = fmap (mark_state (zip our_new [0..])) zipped_res
349 --     -- Mark the args
350 --     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
351 --     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
352 --     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
353 -- 
354 --     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
355 --     mark_state states (id, use) =
356 --       case lookup id states of
357 --         Nothing -> use
358 --         Just state_id -> State state_id
359
360 -- | Returns pairs of signals that should be mapped to state in this function.
361 -- getStateSignals ::
362 --   HsFunction                      -- | The function to look at
363 --   -> FlatFunction                 -- | The function to look at
364 --   -> [(SignalId, SignalId)]   
365 --         -- | TODO The state signals. The first is the state number, the second the
366 --         --   signal to assign the current state to, the last is the signal
367 --         --   that holds the new state.
368 -- 
369 -- getStateSignals hsfunc flatfunc =
370 --   [(old_id, new_id) 
371 --     | (old_num, old_id) <- args
372 --     , (new_num, new_id) <- res
373 --     , old_num == new_num]
374 --   where
375 --     sigs = flat_sigs flatfunc
376 --     -- Translate args and res to lists of (statenum, sigid)
377 --     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
378 --     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
379     
380 -- | Find the given function, flatten it and add it to the session. Then
381 --   (recursively) do the same for any functions used.
382 -- resolvFunc ::
383 --   HsFunction        -- | The function to look for
384 --   -> TranslatorState ()
385 -- 
386 -- resolvFunc hsfunc = do
387 --   flatfuncmap <- getA tsFlatFuncs
388 --   -- Don't do anything if there is already a flat function for this hsfunc or
389 --   -- when it is a builtin function.
390 --   Monad.unless (Map.member hsfunc flatfuncmap) $ do
391 --   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
392 --   -- New function, resolve it
393 --   core <- getA tsCoreModule
394 --   -- Find the named function
395 --   let name = (hsFuncName hsfunc)
396 --   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
397 --   case bind of
398 --     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
399 --     Just b  -> flattenBind hsfunc b
400
401 -- | Translate a top level function declaration to a HsFunction. i.e., which
402 --   interface will be provided by this function. This function essentially
403 --   defines the "calling convention" for hardware models.
404 -- mkHsFunction ::
405 --   Var.Var         -- ^ The function defined
406 --   -> Type         -- ^ The function type (including arguments!)
407 --   -> Bool         -- ^ Is this a stateful function?
408 --   -> HsFunction   -- ^ The resulting HsFunction
409 -- 
410 -- mkHsFunction f ty stateful=
411 --   HsFunction hsname hsargs hsres
412 --   where
413 --     hsname  = getOccString f
414 --     (arg_tys, res_ty) = Type.splitFunTys ty
415 --     (hsargs, hsres) = 
416 --       if stateful 
417 --       then
418 --         let
419 --           -- The last argument must be state
420 --           state_ty = last arg_tys
421 --           state    = useAsState (mkHsValueMap state_ty)
422 --           -- All but the last argument are inports
423 --           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
424 --           hsargs   = inports ++ [state]
425 --           hsres    = case splitTupleType res_ty of
426 --             -- Result type must be a two tuple (state, ports)
427 --             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
428 --               then
429 --                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
430 --               else
431 --                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
432 --             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
433 --         in
434 --           (hsargs, hsres)
435 --       else
436 --         -- Just use everything as a port
437 --         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
438
439 -- | Adds signal names to the given FlatFunction
440 -- nameFlatFunction ::
441 --   FlatFunction
442 --   -> FlatFunction
443 -- 
444 -- nameFlatFunction flatfunc =
445 --   -- Name the signals
446 --   let 
447 --     s = flat_sigs flatfunc
448 --     s' = map nameSignal s in
449 --   flatfunc { flat_sigs = s' }
450 --   where
451 --     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
452 --     nameSignal (id, info) =
453 --       let hints = nameHints info in
454 --       let parts = ("sig" : hints) ++ [show id] in
455 --       let name = concat $ List.intersperse "_" parts in
456 --       (id, info {sigName = Just name})
457 -- 
458 -- -- | Splits a tuple type into a list of element types, or Nothing if the type
459 -- --   is not a tuple type.
460 -- splitTupleType ::
461 --   Type              -- ^ The type to split
462 --   -> Maybe [Type]   -- ^ The tuples element types
463 -- 
464 -- splitTupleType ty =
465 --   case Type.splitTyConApp_maybe ty of
466 --     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
467 --       then
468 --         Just args
469 --       else
470 --         Nothing
471 --     Nothing -> Nothing
472
473 -- vim: set ts=8 sw=2 sts=2 expandtab: