Removed need for GHC.Paths, some functions however require a top libdir
[matthijs/master-project/cλash.git] / cλash / CLasH / Translator.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2
3 module CLasH.Translator where
4
5 import qualified Directory
6 import qualified System.FilePath as FilePath
7 import qualified List
8 import Debug.Trace
9 import qualified Control.Arrow as Arrow
10 import GHC hiding (loadModule, sigName)
11 import CoreSyn
12 import qualified CoreUtils
13 import qualified Var
14 import qualified Type
15 import qualified TyCon
16 import qualified DataCon
17 import qualified HscMain
18 import qualified SrcLoc
19 import qualified FastString
20 import qualified Maybe
21 import qualified Module
22 import qualified Data.Foldable as Foldable
23 import qualified Control.Monad.Trans.State as State
24 import qualified Control.Monad as Monad
25 import Name
26 import qualified Data.Map as Map
27 import Data.Accessor
28 import Data.Generics
29 import NameEnv ( lookupNameEnv )
30 import qualified HscTypes
31 import HscTypes ( cm_binds, cm_types )
32 import MonadUtils ( liftIO )
33 import Outputable ( showSDoc, ppr, showSDocDebug )
34 import DynFlags ( defaultDynFlags )
35 import qualified UniqSupply
36 import List ( find )
37 import qualified List
38 import qualified Monad
39 import qualified Annotations
40 import qualified Serialized
41
42 -- The following modules come from the ForSyDe project. They are really
43 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
44 -- ForSyDe to get access to these modules.
45 import qualified Language.VHDL.AST as AST
46 import qualified Language.VHDL.FileIO
47 import qualified Language.VHDL.Ppr as Ppr
48 -- This is needed for rendering the pretty printed VHDL
49 import Text.PrettyPrint.HughesPJ (render)
50
51 import CLasH.Translator.TranslatorTypes
52 import CLasH.Translator.Annotations
53 import CLasH.Utils.Pretty
54 import CLasH.Normalize
55 import CLasH.VHDL.VHDLTypes
56 import qualified CLasH.VHDL as VHDL
57
58 makeVHDL :: FilePath -> String -> String -> Bool -> IO ()
59 makeVHDL libdir filename name stateful = do
60   -- Load the module
61   (core, env) <- loadModule libdir filename
62   -- Translate to VHDL
63   vhdl <- moduleToVHDL env core [(name, stateful)]
64   -- Write VHDL to file
65   let dir = "./vhdl/" ++ name ++ "/"
66   prepareDir dir
67   mapM (writeVHDL dir) vhdl
68   return ()
69   
70 makeVHDLAnn :: FilePath -> String -> IO ()
71 makeVHDLAnn libdir filename = do
72   (core, top, init, env) <- loadModuleAnn libdir filename
73   let top_entity = head top
74   vhdl <- case init of 
75     [] -> moduleToVHDLAnn env core [top_entity]
76     xs -> moduleToVHDLAnnState env core [(top_entity, (head xs))]
77   let dir = "./vhdl/" ++ (show top_entity) ++ "/"
78   prepareDir dir
79   mapM (writeVHDL dir) vhdl
80   return ()
81
82 listBindings :: FilePath -> String -> IO [()]
83 listBindings libdir filename = do
84   (core, env) <- loadModule libdir filename
85   let binds = CoreSyn.flattenBinds $ cm_binds core
86   mapM (listBinding) binds
87
88 listBinding :: (CoreBndr, CoreExpr) -> IO ()
89 listBinding (b, e) = do
90   putStr "\nBinder: "
91   putStr $ show b
92   putStr "\nExpression: \n"
93   putStr $ prettyShow e
94   putStr "\n\n"
95   putStr $ showSDoc $ ppr e
96   putStr "\n\n"
97   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
98   putStr "\n\n"
99   
100 -- | Show the core structure of the given binds in the given file.
101 listBind :: FilePath -> String -> String -> IO ()
102 listBind libdir filename name = do
103   (core, env) <- loadModule libdir filename
104   let [(b, expr)] = findBinds core [name]
105   putStr "\n"
106   putStr $ prettyShow expr
107   putStr "\n\n"
108   putStr $ showSDoc $ ppr expr
109   putStr "\n\n"
110   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
111   putStr "\n\n"
112
113 -- | Translate the binds with the given names from the given core module to
114 --   VHDL. The Bool in the tuple makes the function stateful (True) or
115 --   stateless (False).
116 moduleToVHDL :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
117 moduleToVHDL env core list = do
118   let (names, statefuls) = unzip list
119   let binds = map fst $ findBinds core names
120   -- Generate a UniqSupply
121   -- Running 
122   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
123   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
124   -- unique supply anywhere.
125   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
126   -- Turn bind into VHDL
127   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
128   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds statefuls
129   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
130   mapM (putStr . render . Ppr.ppr . snd) vhdl
131   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
132   return vhdl
133   
134 moduleToVHDLAnn :: HscTypes.HscEnv -> HscTypes.CoreModule -> [CoreSyn.CoreBndr] -> IO [(AST.VHDLId, AST.DesignFile)]
135 moduleToVHDLAnn env core binds = do
136   -- Generate a UniqSupply
137   -- Running 
138   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
139   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
140   -- unique supply anywhere.
141   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
142   -- Turn bind into VHDL
143   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
144   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [False]
145   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
146   mapM (putStr . render . Ppr.ppr . snd) vhdl
147   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
148   return vhdl
149   
150 moduleToVHDLAnnState :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(CoreSyn.CoreBndr, CoreSyn.CoreBndr)] -> IO [(AST.VHDLId, AST.DesignFile)]
151 moduleToVHDLAnnState env core list = do
152   let (binds, init_states) = unzip list
153   -- Generate a UniqSupply
154   -- Running 
155   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
156   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
157   -- unique supply anywhere.
158   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
159   -- Turn bind into VHDL
160   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
161   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [True]
162   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
163   mapM (putStr . render . Ppr.ppr . snd) vhdl
164   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
165   return vhdl
166
167 -- | Prepares the directory for writing VHDL files. This means creating the
168 --   dir if it does not exist and removing all existing .vhdl files from it.
169 prepareDir :: String -> IO()
170 prepareDir dir = do
171   -- Create the dir if needed
172   exists <- Directory.doesDirectoryExist dir
173   Monad.unless exists $ Directory.createDirectory dir
174   -- Find all .vhdl files in the directory
175   files <- Directory.getDirectoryContents dir
176   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
177   -- Prepend the dirname to the filenames
178   let abs_to_remove = map (FilePath.combine dir) to_remove
179   -- Remove the files
180   mapM_ Directory.removeFile abs_to_remove
181
182 -- | Write the given design file to a file with the given name inside the
183 --   given dir
184 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
185 writeVHDL dir (name, vhdl) = do
186   -- Find the filename
187   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
188   -- Write the file
189   Language.VHDL.FileIO.writeDesignFile vhdl fname
190
191 -- | Loads the given file and turns it into a core module.
192 loadModule :: FilePath -> String -> IO (HscTypes.CoreModule, HscTypes.HscEnv)
193 loadModule libdir filename =
194   defaultErrorHandler defaultDynFlags $ do
195     runGhc (Just libdir) $ do
196       dflags <- getSessionDynFlags
197       setSessionDynFlags dflags
198       --target <- guessTarget "adder.hs" Nothing
199       --liftIO (print (showSDoc (ppr (target))))
200       --liftIO $ printTarget target
201       --setTargets [target]
202       --load LoadAllTargets
203       --core <- GHC.compileToCoreSimplified "Adders.hs"
204       core <- GHC.compileToCoreModule filename
205       env <- GHC.getSession
206       return (core, env)
207       
208 -- | Loads the given file and turns it into a core module.
209 loadModuleAnn :: FilePath -> String -> IO (HscTypes.CoreModule, [CoreSyn.CoreBndr], [CoreSyn.CoreBndr], HscTypes.HscEnv)
210 loadModuleAnn libdir filename =
211   defaultErrorHandler defaultDynFlags $ do
212     runGhc (Just libdir) $ do
213       dflags <- getSessionDynFlags
214       setSessionDynFlags dflags
215       --target <- guessTarget "adder.hs" Nothing
216       --liftIO (print (showSDoc (ppr (target))))
217       --liftIO $ printTarget target
218       --setTargets [target]
219       --load LoadAllTargets
220       --core <- GHC.compileToCoreSimplified "Adders.hs"
221       core <- GHC.compileToCoreModule filename
222       env <- GHC.getSession
223       top_entity <- findTopEntity core
224       init_state <- findInitState core
225       return (core, top_entity, init_state, env)
226
227 findTopEntity :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
228 findTopEntity core = do
229   let binds = CoreSyn.flattenBinds $ cm_binds core
230   topbinds <- Monad.filterM (hasTopEntityAnnotation . fst) binds
231   let bndrs = case topbinds of [] -> error $ "Couldn't find top entity in current module." ; xs -> fst (unzip topbinds)
232   return bndrs
233   
234 findInitState :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
235 findInitState core = do
236   let binds = CoreSyn.flattenBinds $ cm_binds core
237   statebinds <- Monad.filterM (hasInitStateAnnotation . fst) binds
238   let bndrs = case statebinds of [] -> [] ; xs -> fst (unzip statebinds)
239   return bndrs
240   
241 hasTopEntityAnnotation :: GhcMonad m => Var.Var -> m Bool
242 hasTopEntityAnnotation var = do
243   let deserializer = Serialized.deserializeWithData
244   let target = Annotations.NamedTarget (Var.varName var)
245   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
246   let top_ents = filter isTopEntity anns
247   case top_ents of
248     [] -> return False
249     xs -> return True
250     
251 hasInitStateAnnotation :: GhcMonad m => Var.Var -> m Bool
252 hasInitStateAnnotation var = do
253   let deserializer = Serialized.deserializeWithData
254   let target = Annotations.NamedTarget (Var.varName var)
255   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
256   let top_ents = filter isInitState anns
257   case top_ents of
258     [] -> return False
259     xs -> return True
260
261 -- | Extracts the named binds from the given module.
262 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
263 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
264
265 -- | Extract a named bind from the given list of binds
266 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
267 findBind binds lookfor =
268   -- This ignores Recs and compares the name of the bind with lookfor,
269   -- disregarding any namespaces in OccName and extra attributes in Name and
270   -- Var.
271   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
272
273 -- | Flattens the given bind into the given signature and adds it to the
274 --   session. Then (recursively) finds any functions it uses and does the same
275 --   with them.
276 -- flattenBind ::
277 --   HsFunction                         -- The signature to flatten into
278 --   -> (CoreBndr, CoreExpr)            -- The bind to flatten
279 --   -> TranslatorState ()
280 -- 
281 -- flattenBind hsfunc bind@(var, expr) = do
282 --   -- Flatten the function
283 --   let flatfunc = flattenFunction hsfunc bind
284 --   -- Propagate state variables
285 --   let flatfunc' = propagateState hsfunc flatfunc
286 --   -- Store the flat function in the session
287 --   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
288 --   -- Flatten any functions used
289 --   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
290 --   mapM_ resolvFunc used_hsfuncs
291
292 -- | Decide which incoming state variables will become state in the
293 --   given function, and which will be propagate to other applied
294 --   functions.
295 -- propagateState ::
296 --   HsFunction
297 --   -> FlatFunction
298 --   -> FlatFunction
299 -- 
300 -- propagateState hsfunc flatfunc =
301 --     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
302 --   where
303 --     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
304 --     states' = zip olds news
305 --     -- Find all signals used by all sigdefs
306 --     uses = concatMap sigDefUses (flat_defs flatfunc)
307 --     -- Find all signals that are used more than once (is there a
308 --     -- prettier way to do this?)
309 --     multiple_uses = uses List.\\ (List.nub uses)
310 --     -- Find the states whose "old state" signal is used only once
311 --     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
312 --     -- See if these single use states can be propagated
313 --     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
314 --     substate_sigs = concat substate_sigss
315 --     -- Mark any propagated state signals as SigSubState
316 --     sigs' = map 
317 --       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
318 --       (flat_sigs flatfunc)
319
320 -- | Propagate the state into a single function application.
321 -- propagateState' ::
322 --   [(SignalId, SignalId)]
323 --                       -- ^ TODO
324 --   -> SigDef           -- ^ The SigDef to process.
325 --   -> ([SignalId], SigDef) 
326 --                       -- ^ Any signal ids that should become substates,
327 --                       --   and the resulting application.
328 -- 
329 -- propagateState' states def =
330 --     if (is_FApp def) then
331 --       (our_old ++ our_new, def {appFunc = hsfunc'})
332 --     else
333 --       ([], def)
334 --   where
335 --     hsfunc = appFunc def
336 --     args = appArgs def
337 --     res = appRes def
338 --     our_states = filter our_state states
339 --     -- A state signal belongs in this function if the old state is
340 --     -- passed in, and the new state returned
341 --     our_state (old, new) =
342 --       any (old `Foldable.elem`) args
343 --       && new `Foldable.elem` res
344 --     (our_old, our_new) = unzip our_states
345 --     -- Mark the result
346 --     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
347 --     res' = fmap (mark_state (zip our_new [0..])) zipped_res
348 --     -- Mark the args
349 --     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
350 --     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
351 --     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
352 -- 
353 --     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
354 --     mark_state states (id, use) =
355 --       case lookup id states of
356 --         Nothing -> use
357 --         Just state_id -> State state_id
358
359 -- | Returns pairs of signals that should be mapped to state in this function.
360 -- getStateSignals ::
361 --   HsFunction                      -- | The function to look at
362 --   -> FlatFunction                 -- | The function to look at
363 --   -> [(SignalId, SignalId)]   
364 --         -- | TODO The state signals. The first is the state number, the second the
365 --         --   signal to assign the current state to, the last is the signal
366 --         --   that holds the new state.
367 -- 
368 -- getStateSignals hsfunc flatfunc =
369 --   [(old_id, new_id) 
370 --     | (old_num, old_id) <- args
371 --     , (new_num, new_id) <- res
372 --     , old_num == new_num]
373 --   where
374 --     sigs = flat_sigs flatfunc
375 --     -- Translate args and res to lists of (statenum, sigid)
376 --     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
377 --     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
378     
379 -- | Find the given function, flatten it and add it to the session. Then
380 --   (recursively) do the same for any functions used.
381 -- resolvFunc ::
382 --   HsFunction        -- | The function to look for
383 --   -> TranslatorState ()
384 -- 
385 -- resolvFunc hsfunc = do
386 --   flatfuncmap <- getA tsFlatFuncs
387 --   -- Don't do anything if there is already a flat function for this hsfunc or
388 --   -- when it is a builtin function.
389 --   Monad.unless (Map.member hsfunc flatfuncmap) $ do
390 --   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
391 --   -- New function, resolve it
392 --   core <- getA tsCoreModule
393 --   -- Find the named function
394 --   let name = (hsFuncName hsfunc)
395 --   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
396 --   case bind of
397 --     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
398 --     Just b  -> flattenBind hsfunc b
399
400 -- | Translate a top level function declaration to a HsFunction. i.e., which
401 --   interface will be provided by this function. This function essentially
402 --   defines the "calling convention" for hardware models.
403 -- mkHsFunction ::
404 --   Var.Var         -- ^ The function defined
405 --   -> Type         -- ^ The function type (including arguments!)
406 --   -> Bool         -- ^ Is this a stateful function?
407 --   -> HsFunction   -- ^ The resulting HsFunction
408 -- 
409 -- mkHsFunction f ty stateful=
410 --   HsFunction hsname hsargs hsres
411 --   where
412 --     hsname  = getOccString f
413 --     (arg_tys, res_ty) = Type.splitFunTys ty
414 --     (hsargs, hsres) = 
415 --       if stateful 
416 --       then
417 --         let
418 --           -- The last argument must be state
419 --           state_ty = last arg_tys
420 --           state    = useAsState (mkHsValueMap state_ty)
421 --           -- All but the last argument are inports
422 --           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
423 --           hsargs   = inports ++ [state]
424 --           hsres    = case splitTupleType res_ty of
425 --             -- Result type must be a two tuple (state, ports)
426 --             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
427 --               then
428 --                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
429 --               else
430 --                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
431 --             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
432 --         in
433 --           (hsargs, hsres)
434 --       else
435 --         -- Just use everything as a port
436 --         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
437
438 -- | Adds signal names to the given FlatFunction
439 -- nameFlatFunction ::
440 --   FlatFunction
441 --   -> FlatFunction
442 -- 
443 -- nameFlatFunction flatfunc =
444 --   -- Name the signals
445 --   let 
446 --     s = flat_sigs flatfunc
447 --     s' = map nameSignal s in
448 --   flatfunc { flat_sigs = s' }
449 --   where
450 --     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
451 --     nameSignal (id, info) =
452 --       let hints = nameHints info in
453 --       let parts = ("sig" : hints) ++ [show id] in
454 --       let name = concat $ List.intersperse "_" parts in
455 --       (id, info {sigName = Just name})
456 -- 
457 -- -- | Splits a tuple type into a list of element types, or Nothing if the type
458 -- --   is not a tuple type.
459 -- splitTupleType ::
460 --   Type              -- ^ The type to split
461 --   -> Maybe [Type]   -- ^ The tuples element types
462 -- 
463 -- splitTupleType ty =
464 --   case Type.splitTyConApp_maybe ty of
465 --     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
466 --       then
467 --         Just args
468 --       else
469 --         Nothing
470 --     Nothing -> Nothing
471
472 -- vim: set ts=8 sw=2 sts=2 expandtab: