Show type of binder in listBinding.
[matthijs/master-project/cλash.git] / cλash / CLasH / Translator.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2
3 module CLasH.Translator where
4
5 import qualified Directory
6 import qualified System.FilePath as FilePath
7 import qualified List
8 import Debug.Trace
9 import qualified Control.Arrow as Arrow
10 import GHC hiding (loadModule, sigName)
11 import CoreSyn
12 import qualified CoreUtils
13 import qualified Var
14 import qualified Type
15 import qualified TyCon
16 import qualified DataCon
17 import qualified HscMain
18 import qualified SrcLoc
19 import qualified FastString
20 import qualified Maybe
21 import qualified Module
22 import qualified Data.Foldable as Foldable
23 import qualified Control.Monad.Trans.State as State
24 import qualified Control.Monad as Monad
25 import Name
26 import qualified Data.Map as Map
27 import Data.Accessor
28 import Data.Generics
29 import NameEnv ( lookupNameEnv )
30 import qualified HscTypes
31 import HscTypes ( cm_binds, cm_types )
32 import MonadUtils ( liftIO )
33 import Outputable ( showSDoc, ppr, showSDocDebug )
34 import DynFlags ( defaultDynFlags )
35 import qualified UniqSupply
36 import List ( find )
37 import qualified List
38 import qualified Monad
39 import qualified Annotations
40 import qualified Serialized
41
42 -- The following modules come from the ForSyDe project. They are really
43 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
44 -- ForSyDe to get access to these modules.
45 import qualified Language.VHDL.AST as AST
46 import qualified Language.VHDL.FileIO
47 import qualified Language.VHDL.Ppr as Ppr
48 -- This is needed for rendering the pretty printed VHDL
49 import Text.PrettyPrint.HughesPJ (render)
50
51 import CLasH.Translator.TranslatorTypes
52 import CLasH.Translator.Annotations
53 import CLasH.Utils.Pretty
54 import CLasH.Normalize
55 import CLasH.VHDL.VHDLTypes
56 import qualified CLasH.VHDL as VHDL
57
58 makeVHDL :: FilePath -> String -> String -> Bool -> IO ()
59 makeVHDL libdir filename name stateful = do
60   -- Load the module
61   (core, env) <- loadModule libdir filename
62   -- Translate to VHDL
63   vhdl <- moduleToVHDL env core [(name, stateful)]
64   -- Write VHDL to file
65   let dir = "./vhdl/" ++ name ++ "/"
66   prepareDir dir
67   mapM (writeVHDL dir) vhdl
68   return ()
69   
70 makeVHDLAnn :: FilePath -> String -> IO ()
71 makeVHDLAnn libdir filename = do
72   (core, top, init, env) <- loadModuleAnn libdir filename
73   let top_entity = head top
74   vhdl <- case init of 
75     [] -> moduleToVHDLAnn env core [top_entity]
76     xs -> moduleToVHDLAnnState env core [(top_entity, (head xs))]
77   let dir = "./vhdl/" ++ (show top_entity) ++ "/"
78   prepareDir dir
79   mapM (writeVHDL dir) vhdl
80   return ()
81
82 listBindings :: FilePath -> String -> IO [()]
83 listBindings libdir filename = do
84   (core, env) <- loadModule libdir filename
85   let binds = CoreSyn.flattenBinds $ cm_binds core
86   mapM (listBinding) binds
87
88 listBinding :: (CoreBndr, CoreExpr) -> IO ()
89 listBinding (b, e) = do
90   putStr "\nBinder: "
91   putStr $ show b
92   putStr "\nType of Binder: \n"
93   putStr $ showSDoc $ ppr $ Var.varType b
94   putStr "\n\nExpression: \n"
95   putStr $ prettyShow e
96   putStr "\n\n"
97   putStr $ showSDoc $ ppr e
98   putStr "\n\nType of Expression: \n"
99   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
100   putStr "\n\n"
101   
102 -- | Show the core structure of the given binds in the given file.
103 listBind :: FilePath -> String -> String -> IO ()
104 listBind libdir filename name = do
105   (core, env) <- loadModule libdir filename
106   let [(b, expr)] = findBinds core [name]
107   listBinding (b, expr)
108
109 -- | Translate the binds with the given names from the given core module to
110 --   VHDL. The Bool in the tuple makes the function stateful (True) or
111 --   stateless (False).
112 moduleToVHDL :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
113 moduleToVHDL env core list = do
114   let (names, statefuls) = unzip list
115   let binds = map fst $ findBinds core names
116   -- Generate a UniqSupply
117   -- Running 
118   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
119   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
120   -- unique supply anywhere.
121   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
122   -- Turn bind into VHDL
123   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
124   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds statefuls
125   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
126   mapM (putStr . render . Ppr.ppr . snd) vhdl
127   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
128   return vhdl
129   
130 moduleToVHDLAnn :: HscTypes.HscEnv -> HscTypes.CoreModule -> [CoreSyn.CoreBndr] -> IO [(AST.VHDLId, AST.DesignFile)]
131 moduleToVHDLAnn env core binds = do
132   -- Generate a UniqSupply
133   -- Running 
134   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
135   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
136   -- unique supply anywhere.
137   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
138   -- Turn bind into VHDL
139   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
140   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [False]
141   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
142   mapM (putStr . render . Ppr.ppr . snd) vhdl
143   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
144   return vhdl
145   
146 moduleToVHDLAnnState :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(CoreSyn.CoreBndr, CoreSyn.CoreBndr)] -> IO [(AST.VHDLId, AST.DesignFile)]
147 moduleToVHDLAnnState env core list = do
148   let (binds, init_states) = unzip list
149   -- Generate a UniqSupply
150   -- Running 
151   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
152   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
153   -- unique supply anywhere.
154   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
155   -- Turn bind into VHDL
156   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
157   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds [True]
158   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
159   mapM (putStr . render . Ppr.ppr . snd) vhdl
160   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
161   return vhdl
162
163 -- | Prepares the directory for writing VHDL files. This means creating the
164 --   dir if it does not exist and removing all existing .vhdl files from it.
165 prepareDir :: String -> IO()
166 prepareDir dir = do
167   -- Create the dir if needed
168   exists <- Directory.doesDirectoryExist dir
169   Monad.unless exists $ Directory.createDirectory dir
170   -- Find all .vhdl files in the directory
171   files <- Directory.getDirectoryContents dir
172   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
173   -- Prepend the dirname to the filenames
174   let abs_to_remove = map (FilePath.combine dir) to_remove
175   -- Remove the files
176   mapM_ Directory.removeFile abs_to_remove
177
178 -- | Write the given design file to a file with the given name inside the
179 --   given dir
180 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
181 writeVHDL dir (name, vhdl) = do
182   -- Find the filename
183   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
184   -- Write the file
185   Language.VHDL.FileIO.writeDesignFile vhdl fname
186
187 -- | Loads the given file and turns it into a core module.
188 loadModule :: FilePath -> String -> IO (HscTypes.CoreModule, HscTypes.HscEnv)
189 loadModule libdir filename =
190   defaultErrorHandler defaultDynFlags $ do
191     runGhc (Just libdir) $ do
192       dflags <- getSessionDynFlags
193       setSessionDynFlags dflags
194       --target <- guessTarget "adder.hs" Nothing
195       --liftIO (print (showSDoc (ppr (target))))
196       --liftIO $ printTarget target
197       --setTargets [target]
198       --load LoadAllTargets
199       --core <- GHC.compileToCoreSimplified "Adders.hs"
200       core <- GHC.compileToCoreModule filename
201       env <- GHC.getSession
202       return (core, env)
203       
204 -- | Loads the given file and turns it into a core module.
205 loadModuleAnn :: FilePath -> String -> IO (HscTypes.CoreModule, [CoreSyn.CoreBndr], [CoreSyn.CoreBndr], HscTypes.HscEnv)
206 loadModuleAnn libdir filename =
207   defaultErrorHandler defaultDynFlags $ do
208     runGhc (Just libdir) $ do
209       dflags <- getSessionDynFlags
210       setSessionDynFlags dflags
211       --target <- guessTarget "adder.hs" Nothing
212       --liftIO (print (showSDoc (ppr (target))))
213       --liftIO $ printTarget target
214       --setTargets [target]
215       --load LoadAllTargets
216       --core <- GHC.compileToCoreSimplified "Adders.hs"
217       core <- GHC.compileToCoreModule filename
218       env <- GHC.getSession
219       top_entity <- findTopEntity core
220       init_state <- findInitState core
221       return (core, top_entity, init_state, env)
222
223 findTopEntity :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
224 findTopEntity core = do
225   let binds = CoreSyn.flattenBinds $ cm_binds core
226   topbinds <- Monad.filterM (hasTopEntityAnnotation . fst) binds
227   let bndrs = case topbinds of [] -> error $ "Couldn't find top entity in current module." ; xs -> fst (unzip topbinds)
228   return bndrs
229   
230 findInitState :: GhcMonad m => HscTypes.CoreModule -> m [CoreSyn.CoreBndr]
231 findInitState core = do
232   let binds = CoreSyn.flattenBinds $ cm_binds core
233   statebinds <- Monad.filterM (hasInitStateAnnotation . fst) binds
234   let bndrs = case statebinds of [] -> [] ; xs -> fst (unzip statebinds)
235   return bndrs
236   
237 hasTopEntityAnnotation :: GhcMonad m => Var.Var -> m Bool
238 hasTopEntityAnnotation var = do
239   let deserializer = Serialized.deserializeWithData
240   let target = Annotations.NamedTarget (Var.varName var)
241   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
242   let top_ents = filter isTopEntity anns
243   case top_ents of
244     [] -> return False
245     xs -> return True
246     
247 hasInitStateAnnotation :: GhcMonad m => Var.Var -> m Bool
248 hasInitStateAnnotation var = do
249   let deserializer = Serialized.deserializeWithData
250   let target = Annotations.NamedTarget (Var.varName var)
251   (anns :: [CLasHAnn]) <- GHC.findGlobalAnns deserializer target
252   let top_ents = filter isInitState anns
253   case top_ents of
254     [] -> return False
255     xs -> return True
256
257 -- | Extracts the named binds from the given module.
258 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
259 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
260
261 -- | Extract a named bind from the given list of binds
262 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
263 findBind binds lookfor =
264   -- This ignores Recs and compares the name of the bind with lookfor,
265   -- disregarding any namespaces in OccName and extra attributes in Name and
266   -- Var.
267   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
268
269 -- | Flattens the given bind into the given signature and adds it to the
270 --   session. Then (recursively) finds any functions it uses and does the same
271 --   with them.
272 -- flattenBind ::
273 --   HsFunction                         -- The signature to flatten into
274 --   -> (CoreBndr, CoreExpr)            -- The bind to flatten
275 --   -> TranslatorState ()
276 -- 
277 -- flattenBind hsfunc bind@(var, expr) = do
278 --   -- Flatten the function
279 --   let flatfunc = flattenFunction hsfunc bind
280 --   -- Propagate state variables
281 --   let flatfunc' = propagateState hsfunc flatfunc
282 --   -- Store the flat function in the session
283 --   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
284 --   -- Flatten any functions used
285 --   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
286 --   mapM_ resolvFunc used_hsfuncs
287
288 -- | Decide which incoming state variables will become state in the
289 --   given function, and which will be propagate to other applied
290 --   functions.
291 -- propagateState ::
292 --   HsFunction
293 --   -> FlatFunction
294 --   -> FlatFunction
295 -- 
296 -- propagateState hsfunc flatfunc =
297 --     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
298 --   where
299 --     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
300 --     states' = zip olds news
301 --     -- Find all signals used by all sigdefs
302 --     uses = concatMap sigDefUses (flat_defs flatfunc)
303 --     -- Find all signals that are used more than once (is there a
304 --     -- prettier way to do this?)
305 --     multiple_uses = uses List.\\ (List.nub uses)
306 --     -- Find the states whose "old state" signal is used only once
307 --     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
308 --     -- See if these single use states can be propagated
309 --     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
310 --     substate_sigs = concat substate_sigss
311 --     -- Mark any propagated state signals as SigSubState
312 --     sigs' = map 
313 --       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
314 --       (flat_sigs flatfunc)
315
316 -- | Propagate the state into a single function application.
317 -- propagateState' ::
318 --   [(SignalId, SignalId)]
319 --                       -- ^ TODO
320 --   -> SigDef           -- ^ The SigDef to process.
321 --   -> ([SignalId], SigDef) 
322 --                       -- ^ Any signal ids that should become substates,
323 --                       --   and the resulting application.
324 -- 
325 -- propagateState' states def =
326 --     if (is_FApp def) then
327 --       (our_old ++ our_new, def {appFunc = hsfunc'})
328 --     else
329 --       ([], def)
330 --   where
331 --     hsfunc = appFunc def
332 --     args = appArgs def
333 --     res = appRes def
334 --     our_states = filter our_state states
335 --     -- A state signal belongs in this function if the old state is
336 --     -- passed in, and the new state returned
337 --     our_state (old, new) =
338 --       any (old `Foldable.elem`) args
339 --       && new `Foldable.elem` res
340 --     (our_old, our_new) = unzip our_states
341 --     -- Mark the result
342 --     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
343 --     res' = fmap (mark_state (zip our_new [0..])) zipped_res
344 --     -- Mark the args
345 --     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
346 --     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
347 --     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
348 -- 
349 --     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
350 --     mark_state states (id, use) =
351 --       case lookup id states of
352 --         Nothing -> use
353 --         Just state_id -> State state_id
354
355 -- | Returns pairs of signals that should be mapped to state in this function.
356 -- getStateSignals ::
357 --   HsFunction                      -- | The function to look at
358 --   -> FlatFunction                 -- | The function to look at
359 --   -> [(SignalId, SignalId)]   
360 --         -- | TODO The state signals. The first is the state number, the second the
361 --         --   signal to assign the current state to, the last is the signal
362 --         --   that holds the new state.
363 -- 
364 -- getStateSignals hsfunc flatfunc =
365 --   [(old_id, new_id) 
366 --     | (old_num, old_id) <- args
367 --     , (new_num, new_id) <- res
368 --     , old_num == new_num]
369 --   where
370 --     sigs = flat_sigs flatfunc
371 --     -- Translate args and res to lists of (statenum, sigid)
372 --     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
373 --     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
374     
375 -- | Find the given function, flatten it and add it to the session. Then
376 --   (recursively) do the same for any functions used.
377 -- resolvFunc ::
378 --   HsFunction        -- | The function to look for
379 --   -> TranslatorState ()
380 -- 
381 -- resolvFunc hsfunc = do
382 --   flatfuncmap <- getA tsFlatFuncs
383 --   -- Don't do anything if there is already a flat function for this hsfunc or
384 --   -- when it is a builtin function.
385 --   Monad.unless (Map.member hsfunc flatfuncmap) $ do
386 --   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
387 --   -- New function, resolve it
388 --   core <- getA tsCoreModule
389 --   -- Find the named function
390 --   let name = (hsFuncName hsfunc)
391 --   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
392 --   case bind of
393 --     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
394 --     Just b  -> flattenBind hsfunc b
395
396 -- | Translate a top level function declaration to a HsFunction. i.e., which
397 --   interface will be provided by this function. This function essentially
398 --   defines the "calling convention" for hardware models.
399 -- mkHsFunction ::
400 --   Var.Var         -- ^ The function defined
401 --   -> Type         -- ^ The function type (including arguments!)
402 --   -> Bool         -- ^ Is this a stateful function?
403 --   -> HsFunction   -- ^ The resulting HsFunction
404 -- 
405 -- mkHsFunction f ty stateful=
406 --   HsFunction hsname hsargs hsres
407 --   where
408 --     hsname  = getOccString f
409 --     (arg_tys, res_ty) = Type.splitFunTys ty
410 --     (hsargs, hsres) = 
411 --       if stateful 
412 --       then
413 --         let
414 --           -- The last argument must be state
415 --           state_ty = last arg_tys
416 --           state    = useAsState (mkHsValueMap state_ty)
417 --           -- All but the last argument are inports
418 --           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
419 --           hsargs   = inports ++ [state]
420 --           hsres    = case splitTupleType res_ty of
421 --             -- Result type must be a two tuple (state, ports)
422 --             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
423 --               then
424 --                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
425 --               else
426 --                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
427 --             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
428 --         in
429 --           (hsargs, hsres)
430 --       else
431 --         -- Just use everything as a port
432 --         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
433
434 -- | Adds signal names to the given FlatFunction
435 -- nameFlatFunction ::
436 --   FlatFunction
437 --   -> FlatFunction
438 -- 
439 -- nameFlatFunction flatfunc =
440 --   -- Name the signals
441 --   let 
442 --     s = flat_sigs flatfunc
443 --     s' = map nameSignal s in
444 --   flatfunc { flat_sigs = s' }
445 --   where
446 --     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
447 --     nameSignal (id, info) =
448 --       let hints = nameHints info in
449 --       let parts = ("sig" : hints) ++ [show id] in
450 --       let name = concat $ List.intersperse "_" parts in
451 --       (id, info {sigName = Just name})
452 -- 
453 -- -- | Splits a tuple type into a list of element types, or Nothing if the type
454 -- --   is not a tuple type.
455 -- splitTupleType ::
456 --   Type              -- ^ The type to split
457 --   -> Maybe [Type]   -- ^ The tuples element types
458 -- 
459 -- splitTupleType ty =
460 --   case Type.splitTyConApp_maybe ty of
461 --     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
462 --       then
463 --         Just args
464 --       else
465 --         Nothing
466 --     Nothing -> Nothing
467
468 -- vim: set ts=8 sw=2 sts=2 expandtab: