Add automated testbench generation according to supplied test input
[matthijs/master-project/cλash.git] / cλash / CLasH / Translator.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2
3 module CLasH.Translator where
4
5 import qualified Directory
6 import qualified System.FilePath as FilePath
7 import qualified List
8 import Debug.Trace
9 import qualified Control.Arrow as Arrow
10 import GHC hiding (loadModule, sigName)
11 import CoreSyn
12 import qualified CoreUtils
13 import qualified Var
14 import qualified Type
15 import qualified TyCon
16 import qualified DataCon
17 import qualified HscMain
18 import qualified SrcLoc
19 import qualified FastString
20 import qualified Maybe
21 import qualified Module
22 import qualified Data.Foldable as Foldable
23 import qualified Control.Monad.Trans.State as State
24 import qualified Control.Monad as Monad
25 import Name
26 import qualified Data.Map as Map
27 import Data.Accessor
28 import Data.Generics
29 import NameEnv ( lookupNameEnv )
30 import qualified HscTypes
31 import HscTypes ( cm_binds, cm_types )
32 import MonadUtils ( liftIO )
33 import Outputable ( showSDoc, ppr, showSDocDebug )
34 import DynFlags ( defaultDynFlags )
35 import qualified UniqSupply
36 import List ( find )
37 import qualified List
38 import qualified Monad
39 import qualified Annotations
40 import qualified Serialized
41
42 -- The following modules come from the ForSyDe project. They are really
43 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
44 -- ForSyDe to get access to these modules.
45 import qualified Language.VHDL.AST as AST
46 import qualified Language.VHDL.FileIO
47 import qualified Language.VHDL.Ppr as Ppr
48 -- This is needed for rendering the pretty printed VHDL
49 import Text.PrettyPrint.HughesPJ (render)
50
51 import CLasH.Translator.TranslatorTypes
52 import CLasH.Translator.Annotations
53 import CLasH.Utils.Pretty
54 import CLasH.Normalize
55 import CLasH.VHDL.VHDLTypes
56 import CLasH.Utils.Core.CoreTools
57 import qualified CLasH.VHDL as VHDL
58
59 -- makeVHDL :: FilePath -> String -> String -> Bool -> IO ()
60 -- makeVHDL libdir filename name stateful = do
61 --   -- Load the module
62 --   (core, env) <- loadModule libdir filename
63 --   -- Translate to VHDL
64 --   vhdl <- moduleToVHDL env core [(name, stateful)]
65 --   -- Write VHDL to file
66 --   let dir = "./vhdl/" ++ name ++ "/"
67 --   prepareDir dir
68 --   mapM (writeVHDL dir) vhdl
69 --   return ()
70   
71 makeVHDLAnn :: FilePath -> String -> IO ()
72 makeVHDLAnn libdir filename = do
73   (core, top, init, test, env) <- loadModuleAnn libdir filename
74   let top_entity = head top
75   let test_expr = head test
76   vhdl <- case init of 
77     [] -> moduleToVHDLAnn env core (top_entity, test_expr)
78     xs -> moduleToVHDLAnnState env core (top_entity, test_expr, (head xs))
79   let dir = "./vhdl/" ++ (show top_entity) ++ "/"
80   prepareDir dir
81   mapM (writeVHDL dir) vhdl
82   return ()
83
84 listBindings :: FilePath -> String -> IO [()]
85 listBindings libdir filename = do
86   (core, env) <- loadModule libdir filename
87   let binds = CoreSyn.flattenBinds $ cm_binds core
88   mapM (listBinding) binds
89
90 listBinding :: (CoreBndr, CoreExpr) -> IO ()
91 listBinding (b, e) = do
92   putStr "\nBinder: "
93   putStr $ show b
94   putStr "\nExpression: \n"
95   putStr $ prettyShow e
96   putStr "\n\n"
97   putStr $ showSDoc $ ppr e
98   putStr "\n\n"
99   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
100   putStr "\n\n"
101   
102 -- | Show the core structure of the given binds in the given file.
103 listBind :: FilePath -> String -> String -> IO ()
104 listBind libdir filename name = do
105   (core, env) <- loadModule libdir filename
106   let [(b, expr)] = findBinds core [name]
107   putStr "\n"
108   putStr $ prettyShow expr
109   putStr "\n\n"
110   putStr $ showSDoc $ ppr expr
111   putStr "\n\n"
112   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
113   putStr "\n\n"  
114
115 -- | Translate the binds with the given names from the given core module to
116 --   VHDL. The Bool in the tuple makes the function stateful (True) or
117 --   stateless (False).
118 -- moduleToVHDL :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
119 -- moduleToVHDL env core list = do
120 --   let (names, statefuls) = unzip list
121 --   let binds = map fst $ findBinds core names
122 --   -- Generate a UniqSupply
123 --   -- Running 
124 --   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
125 --   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
126 --   -- unique supply anywhere.
127 --   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
128 --   -- Turn bind into VHDL
129 --   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
130 --   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds statefuls
131 --   let vhdl = VHDL.createDesignFiles typestate normalized_bindings binds
132 --   mapM (putStr . render . Ppr.ppr . snd) vhdl
133 --   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
134 --   return vhdl
135   
136 moduleToVHDLAnn :: HscTypes.HscEnv -> HscTypes.CoreModule -> (CoreSyn.CoreBndr, CoreSyn.CoreExpr) -> IO [(AST.VHDLId, AST.DesignFile)]
137 moduleToVHDLAnn env core (topbind, test) = do
138   -- Generate a UniqSupply
139   -- Running 
140   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
141   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
142   -- unique supply anywhere.
143   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
144   -- Turn bind into VHDL
145   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
146   let testexprs = reduceCoreListToHsList test
147   let (normalized_bindings, test_bindings, typestate) = normalizeModule env uniqSupply all_bindings testexprs [topbind] [False]
148   let vhdl = VHDL.createDesignFiles typestate normalized_bindings topbind test_bindings
149   mapM (putStr . render . Ppr.ppr . snd) vhdl
150   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
151   return vhdl
152   
153 moduleToVHDLAnnState :: HscTypes.HscEnv -> HscTypes.CoreModule -> (CoreSyn.CoreBndr, CoreSyn.CoreExpr, CoreSyn.CoreBndr) -> IO [(AST.VHDLId, AST.DesignFile)]
154 moduleToVHDLAnnState env core (topbind, test, init_state) = do
155   -- Generate a UniqSupply
156   -- Running 
157   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
158   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
159   -- unique supply anywhere.
160   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
161   -- Turn bind into VHDL
162   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
163   let testexprs = reduceCoreListToHsList test
164   let (normalized_bindings, test_bindings, typestate) = normalizeModule env uniqSupply all_bindings testexprs [topbind] [True]
165   let vhdl = VHDL.createDesignFiles typestate normalized_bindings topbind test_bindings
166   mapM (putStr . render . Ppr.ppr . snd) vhdl
167   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
168   return vhdl
169
170 -- | Prepares the directory for writing VHDL files. This means creating the
171 --   dir if it does not exist and removing all existing .vhdl files from it.
172 prepareDir :: String -> IO()
173 prepareDir dir = do
174   -- Create the dir if needed
175   exists <- Directory.doesDirectoryExist dir
176   Monad.unless exists $ Directory.createDirectory dir
177   -- Find all .vhdl files in the directory
178   files <- Directory.getDirectoryContents dir
179   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
180   -- Prepend the dirname to the filenames
181   let abs_to_remove = map (FilePath.combine dir) to_remove
182   -- Remove the files
183   mapM_ Directory.removeFile abs_to_remove
184
185 -- | Write the given design file to a file with the given name inside the
186 --   given dir
187 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
188 writeVHDL dir (name, vhdl) = do
189   -- Find the filename
190   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
191   -- Write the file
192   Language.VHDL.FileIO.writeDesignFile vhdl fname
193
194 -- | Loads the given file and turns it into a core module.
195 loadModule :: FilePath -> String -> IO (HscTypes.CoreModule, HscTypes.HscEnv)
196 loadModule libdir filename =
197   defaultErrorHandler defaultDynFlags $ do
198     runGhc (Just libdir) $ do
199       dflags <- getSessionDynFlags
200       setSessionDynFlags dflags
201       --target <- guessTarget "adder.hs" Nothing
202       --liftIO (print (showSDoc (ppr (target))))
203       --liftIO $ printTarget target
204       --setTargets [target]
205       --load LoadAllTargets
206       --core <- GHC.compileToCoreSimplified "Adders.hs"
207       core <- GHC.compileToCoreModule filename
208       env <- GHC.getSession
209       return (core, env)
210       
211 -- | Loads the given file and turns it into a core module.
212 loadModuleAnn :: FilePath -> String -> IO (HscTypes.CoreModule, [CoreSyn.CoreBndr], [CoreSyn.CoreBndr], [CoreSyn.CoreExpr], HscTypes.HscEnv)
213 loadModuleAnn libdir filename =
214   defaultErrorHandler defaultDynFlags $ do
215     runGhc (Just libdir) $ do
216       dflags <- getSessionDynFlags
217       setSessionDynFlags dflags
218       --target <- guessTarget "adder.hs" Nothing
219       --liftIO (print (showSDoc (ppr (target))))
220       --liftIO $ printTarget target
221       --setTargets [target]
222       --load LoadAllTargets
223       --core <- GHC.compileToCoreSimplified "Adders.hs"
224       core <- GHC.compileToCoreModule filename
225       env <- GHC.getSession
226       top_entity <- findTopEntity core
227       init_state <- findInitState core
228       test_input <- findTestInput core
229       return (core, top_entity, init_state, test_input, env)
230
231 findTopEntity :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
232 findTopEntity core = do
233   let binds = CoreSyn.flattenBinds $ cm_binds core
234   topbinds <- Monad.filterM (hasTopEntityAnnotation . fst) binds
235   let bndrs = case topbinds of [] -> error $ "Couldn't find top entity in current module." ; xs -> fst (unzip topbinds)
236   return bndrs
237   
238 findInitState :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
239 findInitState core = do
240   let binds = CoreSyn.flattenBinds $ cm_binds core
241   statebinds <- Monad.filterM (hasInitStateAnnotation . fst) binds
242   let bndrs = case statebinds of [] -> [] ; xs -> fst (unzip statebinds)
243   return bndrs
244   
245 findTestInput :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreExpr]
246 findTestInput core = do
247   let binds = CoreSyn.flattenBinds $ cm_binds core
248   testbinds <- Monad.filterM (hasTestInputAnnotation . fst) binds
249   let exprs = case testbinds of [] -> [] ; xs -> snd (unzip testbinds)
250   return exprs
251   
252 hasTopEntityAnnotation :: GhcMonad m => Var.Var -> m Bool
253 hasTopEntityAnnotation var = do
254   let deserializer = Serialized.deserializeWithData
255   let target = Annotations.NamedTarget (Var.varName var)
256   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
257   let top_ents = filter isTopEntity anns
258   case top_ents of
259     [] -> return False
260     xs -> return True
261     
262 hasInitStateAnnotation :: GhcMonad m => Var.Var -> m Bool
263 hasInitStateAnnotation var = do
264   let deserializer = Serialized.deserializeWithData
265   let target = Annotations.NamedTarget (Var.varName var)
266   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
267   let top_ents = filter isInitState anns
268   case top_ents of
269     [] -> return False
270     xs -> return True
271     
272 hasTestInputAnnotation :: GhcMonad m => Var.Var -> m Bool
273 hasTestInputAnnotation var = do
274   let deserializer = Serialized.deserializeWithData
275   let target = Annotations.NamedTarget (Var.varName var)
276   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
277   let top_ents = filter isTestInput anns
278   case top_ents of
279     [] -> return False
280     xs -> return True
281
282 -- | Extracts the named binds from the given module.
283 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
284 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
285
286 -- | Extract a named bind from the given list of binds
287 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
288 findBind binds lookfor =
289   -- This ignores Recs and compares the name of the bind with lookfor,
290   -- disregarding any namespaces in OccName and extra attributes in Name and
291   -- Var.
292   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
293
294 -- | Flattens the given bind into the given signature and adds it to the
295 --   session. Then (recursively) finds any functions it uses and does the same
296 --   with them.
297 -- flattenBind ::
298 --   HsFunction                         -- The signature to flatten into
299 --   -> (CoreBndr, CoreExpr)            -- The bind to flatten
300 --   -> TranslatorState ()
301 -- 
302 -- flattenBind hsfunc bind@(var, expr) = do
303 --   -- Flatten the function
304 --   let flatfunc = flattenFunction hsfunc bind
305 --   -- Propagate state variables
306 --   let flatfunc' = propagateState hsfunc flatfunc
307 --   -- Store the flat function in the session
308 --   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
309 --   -- Flatten any functions used
310 --   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
311 --   mapM_ resolvFunc used_hsfuncs
312
313 -- | Decide which incoming state variables will become state in the
314 --   given function, and which will be propagate to other applied
315 --   functions.
316 -- propagateState ::
317 --   HsFunction
318 --   -> FlatFunction
319 --   -> FlatFunction
320 -- 
321 -- propagateState hsfunc flatfunc =
322 --     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
323 --   where
324 --     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
325 --     states' = zip olds news
326 --     -- Find all signals used by all sigdefs
327 --     uses = concatMap sigDefUses (flat_defs flatfunc)
328 --     -- Find all signals that are used more than once (is there a
329 --     -- prettier way to do this?)
330 --     multiple_uses = uses List.\\ (List.nub uses)
331 --     -- Find the states whose "old state" signal is used only once
332 --     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
333 --     -- See if these single use states can be propagated
334 --     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
335 --     substate_sigs = concat substate_sigss
336 --     -- Mark any propagated state signals as SigSubState
337 --     sigs' = map 
338 --       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
339 --       (flat_sigs flatfunc)
340
341 -- | Propagate the state into a single function application.
342 -- propagateState' ::
343 --   [(SignalId, SignalId)]
344 --                       -- ^ TODO
345 --   -> SigDef           -- ^ The SigDef to process.
346 --   -> ([SignalId], SigDef) 
347 --                       -- ^ Any signal ids that should become substates,
348 --                       --   and the resulting application.
349 -- 
350 -- propagateState' states def =
351 --     if (is_FApp def) then
352 --       (our_old ++ our_new, def {appFunc = hsfunc'})
353 --     else
354 --       ([], def)
355 --   where
356 --     hsfunc = appFunc def
357 --     args = appArgs def
358 --     res = appRes def
359 --     our_states = filter our_state states
360 --     -- A state signal belongs in this function if the old state is
361 --     -- passed in, and the new state returned
362 --     our_state (old, new) =
363 --       any (old `Foldable.elem`) args
364 --       && new `Foldable.elem` res
365 --     (our_old, our_new) = unzip our_states
366 --     -- Mark the result
367 --     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
368 --     res' = fmap (mark_state (zip our_new [0..])) zipped_res
369 --     -- Mark the args
370 --     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
371 --     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
372 --     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
373 -- 
374 --     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
375 --     mark_state states (id, use) =
376 --       case lookup id states of
377 --         Nothing -> use
378 --         Just state_id -> State state_id
379
380 -- | Returns pairs of signals that should be mapped to state in this function.
381 -- getStateSignals ::
382 --   HsFunction                      -- | The function to look at
383 --   -> FlatFunction                 -- | The function to look at
384 --   -> [(SignalId, SignalId)]   
385 --         -- | TODO The state signals. The first is the state number, the second the
386 --         --   signal to assign the current state to, the last is the signal
387 --         --   that holds the new state.
388 -- 
389 -- getStateSignals hsfunc flatfunc =
390 --   [(old_id, new_id) 
391 --     | (old_num, old_id) <- args
392 --     , (new_num, new_id) <- res
393 --     , old_num == new_num]
394 --   where
395 --     sigs = flat_sigs flatfunc
396 --     -- Translate args and res to lists of (statenum, sigid)
397 --     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
398 --     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
399     
400 -- | Find the given function, flatten it and add it to the session. Then
401 --   (recursively) do the same for any functions used.
402 -- resolvFunc ::
403 --   HsFunction        -- | The function to look for
404 --   -> TranslatorState ()
405 -- 
406 -- resolvFunc hsfunc = do
407 --   flatfuncmap <- getA tsFlatFuncs
408 --   -- Don't do anything if there is already a flat function for this hsfunc or
409 --   -- when it is a builtin function.
410 --   Monad.unless (Map.member hsfunc flatfuncmap) $ do
411 --   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
412 --   -- New function, resolve it
413 --   core <- getA tsCoreModule
414 --   -- Find the named function
415 --   let name = (hsFuncName hsfunc)
416 --   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
417 --   case bind of
418 --     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
419 --     Just b  -> flattenBind hsfunc b
420
421 -- | Translate a top level function declaration to a HsFunction. i.e., which
422 --   interface will be provided by this function. This function essentially
423 --   defines the "calling convention" for hardware models.
424 -- mkHsFunction ::
425 --   Var.Var         -- ^ The function defined
426 --   -> Type         -- ^ The function type (including arguments!)
427 --   -> Bool         -- ^ Is this a stateful function?
428 --   -> HsFunction   -- ^ The resulting HsFunction
429 -- 
430 -- mkHsFunction f ty stateful=
431 --   HsFunction hsname hsargs hsres
432 --   where
433 --     hsname  = getOccString f
434 --     (arg_tys, res_ty) = Type.splitFunTys ty
435 --     (hsargs, hsres) = 
436 --       if stateful 
437 --       then
438 --         let
439 --           -- The last argument must be state
440 --           state_ty = last arg_tys
441 --           state    = useAsState (mkHsValueMap state_ty)
442 --           -- All but the last argument are inports
443 --           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
444 --           hsargs   = inports ++ [state]
445 --           hsres    = case splitTupleType res_ty of
446 --             -- Result type must be a two tuple (state, ports)
447 --             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
448 --               then
449 --                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
450 --               else
451 --                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
452 --             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
453 --         in
454 --           (hsargs, hsres)
455 --       else
456 --         -- Just use everything as a port
457 --         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
458
459 -- | Adds signal names to the given FlatFunction
460 -- nameFlatFunction ::
461 --   FlatFunction
462 --   -> FlatFunction
463 -- 
464 -- nameFlatFunction flatfunc =
465 --   -- Name the signals
466 --   let 
467 --     s = flat_sigs flatfunc
468 --     s' = map nameSignal s in
469 --   flatfunc { flat_sigs = s' }
470 --   where
471 --     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
472 --     nameSignal (id, info) =
473 --       let hints = nameHints info in
474 --       let parts = ("sig" : hints) ++ [show id] in
475 --       let name = concat $ List.intersperse "_" parts in
476 --       (id, info {sigName = Just name})
477 -- 
478 -- -- | Splits a tuple type into a list of element types, or Nothing if the type
479 -- --   is not a tuple type.
480 -- splitTupleType ::
481 --   Type              -- ^ The type to split
482 --   -> Maybe [Type]   -- ^ The tuples element types
483 -- 
484 -- splitTupleType ty =
485 --   case Type.splitTyConApp_maybe ty of
486 --     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
487 --       then
488 --         Just args
489 --       else
490 --         Nothing
491 --     Nothing -> Nothing
492
493 -- vim: set ts=8 sw=2 sts=2 expandtab: