Great speed-up in type generation
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import qualified Directory
3 import qualified System.FilePath as FilePath
4 import qualified List
5 import Debug.Trace
6 import qualified Control.Arrow as Arrow
7 import GHC hiding (loadModule, sigName)
8 import CoreSyn
9 import qualified CoreUtils
10 import qualified Var
11 import qualified Type
12 import qualified TyCon
13 import qualified DataCon
14 import qualified HscMain
15 import qualified SrcLoc
16 import qualified FastString
17 import qualified Maybe
18 import qualified Module
19 import qualified Data.Foldable as Foldable
20 import qualified Control.Monad.Trans.State as State
21 import Name
22 import qualified Data.Map as Map
23 import Data.Accessor
24 import Data.Generics
25 import NameEnv ( lookupNameEnv )
26 import qualified HscTypes
27 import HscTypes ( cm_binds, cm_types )
28 import MonadUtils ( liftIO )
29 import Outputable ( showSDoc, ppr, showSDocDebug )
30 import GHC.Paths ( libdir )
31 import DynFlags ( defaultDynFlags )
32 import qualified UniqSupply
33 import List ( find )
34 import qualified List
35 import qualified Monad
36
37 -- The following modules come from the ForSyDe project. They are really
38 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
39 -- ForSyDe to get access to these modules.
40 import qualified ForSyDe.Backend.VHDL.AST as AST
41 import qualified ForSyDe.Backend.VHDL.Ppr
42 import qualified ForSyDe.Backend.VHDL.FileIO
43 import qualified ForSyDe.Backend.Ppr
44 -- This is needed for rendering the pretty printed VHDL
45 import Text.PrettyPrint.HughesPJ (render)
46
47 import TranslatorTypes
48 import HsValueMap
49 import Pretty
50 import Normalize
51 -- import Flatten
52 -- import FlattenTypes
53 import VHDLTypes
54 import qualified VHDL
55
56 makeVHDL :: String -> String -> Bool -> IO ()
57 makeVHDL filename name stateful = do
58   -- Load the module
59   (core, env) <- loadModule filename
60   -- Translate to VHDL
61   vhdl <- moduleToVHDL env core [(name, stateful)]
62   -- Write VHDL to file
63   let dir = "./vhdl/" ++ name ++ "/"
64   prepareDir dir
65   mapM (writeVHDL dir) vhdl
66   return ()
67
68 listBindings :: String -> IO [()]
69 listBindings filename = do
70   (core, env) <- loadModule filename
71   let binds = CoreSyn.flattenBinds $ cm_binds core
72   mapM (listBinding) binds
73
74 listBinding :: (CoreBndr, CoreExpr) -> IO ()
75 listBinding (b, e) = do
76   putStr "\nBinder: "
77   putStr $ show b
78   putStr "\nExpression: \n"
79   putStr $ prettyShow e
80   putStr "\n\n"
81   putStr $ showSDoc $ ppr e
82   putStr "\n\n"
83   putStr $ showSDoc $ ppr $ CoreUtils.exprType e
84   putStr "\n\n"
85   
86 -- | Show the core structure of the given binds in the given file.
87 listBind :: String -> String -> IO ()
88 listBind filename name = do
89   (core, env) <- loadModule filename
90   let [(b, expr)] = findBinds core [name]
91   putStr "\n"
92   putStr $ prettyShow expr
93   putStr "\n\n"
94   putStr $ showSDoc $ ppr expr
95   putStr "\n\n"
96   putStr $ showSDoc $ ppr $ CoreUtils.exprType expr
97   putStr "\n\n"
98
99 -- | Translate the binds with the given names from the given core module to
100 --   VHDL. The Bool in the tuple makes the function stateful (True) or
101 --   stateless (False).
102 moduleToVHDL :: HscTypes.HscEnv -> HscTypes.CoreModule -> [(String, Bool)] -> IO [(AST.VHDLId, AST.DesignFile)]
103 moduleToVHDL env core list = do
104   let (names, statefuls) = unzip list
105   let binds = map fst $ findBinds core names
106   -- Generate a UniqSupply
107   -- Running 
108   --    egrep -r "(initTcRnIf|mkSplitUniqSupply)" .
109   -- on the compiler dir of ghc suggests that 'z' is not used to generate a
110   -- unique supply anywhere.
111   uniqSupply <- UniqSupply.mkSplitUniqSupply 'z'
112   -- Turn bind into VHDL
113   let all_bindings = (CoreSyn.flattenBinds $ cm_binds core)
114   let (normalized_bindings, typestate) = normalizeModule env uniqSupply all_bindings binds statefuls
115   let vhdl = VHDL.createDesignFiles typestate normalized_bindings
116   mapM (putStr . render . ForSyDe.Backend.Ppr.ppr . snd) vhdl
117   --putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
118   return vhdl
119   where
120
121 -- | Prepares the directory for writing VHDL files. This means creating the
122 --   dir if it does not exist and removing all existing .vhdl files from it.
123 prepareDir :: String -> IO()
124 prepareDir dir = do
125   -- Create the dir if needed
126   exists <- Directory.doesDirectoryExist dir
127   Monad.unless exists $ Directory.createDirectory dir
128   -- Find all .vhdl files in the directory
129   files <- Directory.getDirectoryContents dir
130   let to_remove = filter ((==".vhdl") . FilePath.takeExtension) files
131   -- Prepend the dirname to the filenames
132   let abs_to_remove = map (FilePath.combine dir) to_remove
133   -- Remove the files
134   mapM_ Directory.removeFile abs_to_remove
135
136 -- | Write the given design file to a file with the given name inside the
137 --   given dir
138 writeVHDL :: String -> (AST.VHDLId, AST.DesignFile) -> IO ()
139 writeVHDL dir (name, vhdl) = do
140   -- Find the filename
141   let fname = dir ++ (AST.fromVHDLId name) ++ ".vhdl"
142   -- Write the file
143   ForSyDe.Backend.VHDL.FileIO.writeDesignFile vhdl fname
144
145 -- | Loads the given file and turns it into a core module.
146 loadModule :: String -> IO (HscTypes.CoreModule, HscTypes.HscEnv)
147 loadModule filename =
148   defaultErrorHandler defaultDynFlags $ do
149     runGhc (Just libdir) $ do
150       dflags <- getSessionDynFlags
151       setSessionDynFlags dflags
152       --target <- guessTarget "adder.hs" Nothing
153       --liftIO (print (showSDoc (ppr (target))))
154       --liftIO $ printTarget target
155       --setTargets [target]
156       --load LoadAllTargets
157       --core <- GHC.compileToCoreSimplified "Adders.hs"
158       core <- GHC.compileToCoreModule filename
159       env <- GHC.getSession
160       return (core, env)
161
162 -- | Extracts the named binds from the given module.
163 findBinds :: HscTypes.CoreModule -> [String] -> [(CoreBndr, CoreExpr)]
164 findBinds core names = Maybe.mapMaybe (findBind (CoreSyn.flattenBinds $ cm_binds core)) names
165
166 -- | Extract a named bind from the given list of binds
167 findBind :: [(CoreBndr, CoreExpr)] -> String -> Maybe (CoreBndr, CoreExpr)
168 findBind binds lookfor =
169   -- This ignores Recs and compares the name of the bind with lookfor,
170   -- disregarding any namespaces in OccName and extra attributes in Name and
171   -- Var.
172   find (\(var, _) -> lookfor == (occNameString $ nameOccName $ getName var)) binds
173
174 -- | Flattens the given bind into the given signature and adds it to the
175 --   session. Then (recursively) finds any functions it uses and does the same
176 --   with them.
177 -- flattenBind ::
178 --   HsFunction                         -- The signature to flatten into
179 --   -> (CoreBndr, CoreExpr)            -- The bind to flatten
180 --   -> TranslatorState ()
181 -- 
182 -- flattenBind hsfunc bind@(var, expr) = do
183 --   -- Flatten the function
184 --   let flatfunc = flattenFunction hsfunc bind
185 --   -- Propagate state variables
186 --   let flatfunc' = propagateState hsfunc flatfunc
187 --   -- Store the flat function in the session
188 --   modA tsFlatFuncs (Map.insert hsfunc flatfunc')
189 --   -- Flatten any functions used
190 --   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc')
191 --   mapM_ resolvFunc used_hsfuncs
192
193 -- | Decide which incoming state variables will become state in the
194 --   given function, and which will be propagate to other applied
195 --   functions.
196 -- propagateState ::
197 --   HsFunction
198 --   -> FlatFunction
199 --   -> FlatFunction
200 -- 
201 -- propagateState hsfunc flatfunc =
202 --     flatfunc {flat_defs = apps', flat_sigs = sigs'} 
203 --   where
204 --     (olds, news) = unzip $ getStateSignals hsfunc flatfunc
205 --     states' = zip olds news
206 --     -- Find all signals used by all sigdefs
207 --     uses = concatMap sigDefUses (flat_defs flatfunc)
208 --     -- Find all signals that are used more than once (is there a
209 --     -- prettier way to do this?)
210 --     multiple_uses = uses List.\\ (List.nub uses)
211 --     -- Find the states whose "old state" signal is used only once
212 --     single_use_states = filter ((`notElem` multiple_uses) . fst) states'
213 --     -- See if these single use states can be propagated
214 --     (substate_sigss, apps') = unzip $ map (propagateState' single_use_states) (flat_defs flatfunc)
215 --     substate_sigs = concat substate_sigss
216 --     -- Mark any propagated state signals as SigSubState
217 --     sigs' = map 
218 --       (\(id, info) -> (id, if id `elem` substate_sigs then info {sigUse = SigSubState} else info))
219 --       (flat_sigs flatfunc)
220
221 -- | Propagate the state into a single function application.
222 -- propagateState' ::
223 --   [(SignalId, SignalId)]
224 --                       -- ^ TODO
225 --   -> SigDef           -- ^ The SigDef to process.
226 --   -> ([SignalId], SigDef) 
227 --                       -- ^ Any signal ids that should become substates,
228 --                       --   and the resulting application.
229 -- 
230 -- propagateState' states def =
231 --     if (is_FApp def) then
232 --       (our_old ++ our_new, def {appFunc = hsfunc'})
233 --     else
234 --       ([], def)
235 --   where
236 --     hsfunc = appFunc def
237 --     args = appArgs def
238 --     res = appRes def
239 --     our_states = filter our_state states
240 --     -- A state signal belongs in this function if the old state is
241 --     -- passed in, and the new state returned
242 --     our_state (old, new) =
243 --       any (old `Foldable.elem`) args
244 --       && new `Foldable.elem` res
245 --     (our_old, our_new) = unzip our_states
246 --     -- Mark the result
247 --     zipped_res = zipValueMaps res (hsFuncRes hsfunc)
248 --     res' = fmap (mark_state (zip our_new [0..])) zipped_res
249 --     -- Mark the args
250 --     zipped_args = zipWith zipValueMaps args (hsFuncArgs hsfunc)
251 --     args' = map (fmap (mark_state (zip our_old [0..]))) zipped_args
252 --     hsfunc' = hsfunc {hsFuncArgs = args', hsFuncRes = res'}
253 -- 
254 --     mark_state :: [(SignalId, StateId)] -> (SignalId, HsValueUse) -> HsValueUse
255 --     mark_state states (id, use) =
256 --       case lookup id states of
257 --         Nothing -> use
258 --         Just state_id -> State state_id
259
260 -- | Returns pairs of signals that should be mapped to state in this function.
261 -- getStateSignals ::
262 --   HsFunction                      -- | The function to look at
263 --   -> FlatFunction                 -- | The function to look at
264 --   -> [(SignalId, SignalId)]   
265 --         -- | TODO The state signals. The first is the state number, the second the
266 --         --   signal to assign the current state to, the last is the signal
267 --         --   that holds the new state.
268 -- 
269 -- getStateSignals hsfunc flatfunc =
270 --   [(old_id, new_id) 
271 --     | (old_num, old_id) <- args
272 --     , (new_num, new_id) <- res
273 --     , old_num == new_num]
274 --   where
275 --     sigs = flat_sigs flatfunc
276 --     -- Translate args and res to lists of (statenum, sigid)
277 --     args = concat $ zipWith stateList (hsFuncArgs hsfunc) (flat_args flatfunc)
278 --     res = stateList (hsFuncRes hsfunc) (flat_res flatfunc)
279     
280 -- | Find the given function, flatten it and add it to the session. Then
281 --   (recursively) do the same for any functions used.
282 -- resolvFunc ::
283 --   HsFunction        -- | The function to look for
284 --   -> TranslatorState ()
285 -- 
286 -- resolvFunc hsfunc = do
287 --   flatfuncmap <- getA tsFlatFuncs
288 --   -- Don't do anything if there is already a flat function for this hsfunc or
289 --   -- when it is a builtin function.
290 --   Monad.unless (Map.member hsfunc flatfuncmap) $ do
291 --   -- Not working with new builtins -- Monad.unless (elem hsfunc VHDL.builtin_hsfuncs) $ do
292 --   -- New function, resolve it
293 --   core <- getA tsCoreModule
294 --   -- Find the named function
295 --   let name = (hsFuncName hsfunc)
296 --   let bind = findBind (CoreSyn.flattenBinds $ cm_binds core) name 
297 --   case bind of
298 --     Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
299 --     Just b  -> flattenBind hsfunc b
300
301 -- | Translate a top level function declaration to a HsFunction. i.e., which
302 --   interface will be provided by this function. This function essentially
303 --   defines the "calling convention" for hardware models.
304 -- mkHsFunction ::
305 --   Var.Var         -- ^ The function defined
306 --   -> Type         -- ^ The function type (including arguments!)
307 --   -> Bool         -- ^ Is this a stateful function?
308 --   -> HsFunction   -- ^ The resulting HsFunction
309 -- 
310 -- mkHsFunction f ty stateful=
311 --   HsFunction hsname hsargs hsres
312 --   where
313 --     hsname  = getOccString f
314 --     (arg_tys, res_ty) = Type.splitFunTys ty
315 --     (hsargs, hsres) = 
316 --       if stateful 
317 --       then
318 --         let
319 --           -- The last argument must be state
320 --           state_ty = last arg_tys
321 --           state    = useAsState (mkHsValueMap state_ty)
322 --           -- All but the last argument are inports
323 --           inports = map (useAsPort . mkHsValueMap)(init arg_tys)
324 --           hsargs   = inports ++ [state]
325 --           hsres    = case splitTupleType res_ty of
326 --             -- Result type must be a two tuple (state, ports)
327 --             Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
328 --               then
329 --                 Tuple [state, useAsPort (mkHsValueMap outport_ty)]
330 --               else
331 --                 error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
332 --             otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
333 --         in
334 --           (hsargs, hsres)
335 --       else
336 --         -- Just use everything as a port
337 --         (map (useAsPort . mkHsValueMap) arg_tys, useAsPort $ mkHsValueMap res_ty)
338
339 -- | Adds signal names to the given FlatFunction
340 -- nameFlatFunction ::
341 --   FlatFunction
342 --   -> FlatFunction
343 -- 
344 -- nameFlatFunction flatfunc =
345 --   -- Name the signals
346 --   let 
347 --     s = flat_sigs flatfunc
348 --     s' = map nameSignal s in
349 --   flatfunc { flat_sigs = s' }
350 --   where
351 --     nameSignal :: (SignalId, SignalInfo) -> (SignalId, SignalInfo)
352 --     nameSignal (id, info) =
353 --       let hints = nameHints info in
354 --       let parts = ("sig" : hints) ++ [show id] in
355 --       let name = concat $ List.intersperse "_" parts in
356 --       (id, info {sigName = Just name})
357 -- 
358 -- -- | Splits a tuple type into a list of element types, or Nothing if the type
359 -- --   is not a tuple type.
360 -- splitTupleType ::
361 --   Type              -- ^ The type to split
362 --   -> Maybe [Type]   -- ^ The tuples element types
363 -- 
364 -- splitTupleType ty =
365 --   case Type.splitTyConApp_maybe ty of
366 --     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
367 --       then
368 --         Just args
369 --       else
370 --         Nothing
371 --     Nothing -> Nothing
372
373 -- vim: set ts=8 sw=2 sts=2 expandtab: