ddd09fc340da9a401a24e1399c797de60c238234
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import GHC hiding (loadModule)
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import qualified Data.Map as Map
14 import Data.Generics
15 import NameEnv ( lookupNameEnv )
16 import qualified HscTypes
17 import HscTypes ( cm_binds, cm_types )
18 import MonadUtils ( liftIO )
19 import Outputable ( showSDoc, ppr )
20 import GHC.Paths ( libdir )
21 import DynFlags ( defaultDynFlags )
22 import List ( find )
23 import qualified List
24 import qualified Monad
25
26 -- The following modules come from the ForSyDe project. They are really
27 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
28 -- ForSyDe to get access to these modules.
29 import qualified ForSyDe.Backend.VHDL.AST as AST
30 import qualified ForSyDe.Backend.VHDL.Ppr
31 import qualified ForSyDe.Backend.VHDL.FileIO
32 import qualified ForSyDe.Backend.Ppr
33 -- This is needed for rendering the pretty printed VHDL
34 import Text.PrettyPrint.HughesPJ (render)
35
36 import TranslatorTypes
37 import HsValueMap
38 import Pretty
39 import Flatten
40 import FlattenTypes
41 import VHDLTypes
42 import qualified VHDL
43
44 main = do
45   makeVHDL "Alu.hs" "salu"
46
47 makeVHDL :: String -> String -> IO ()
48 makeVHDL filename name = do
49   -- Load the module
50   core <- loadModule filename
51   -- Translate to VHDL
52   vhdl <- moduleToVHDL core [name]
53   -- Write VHDL to file
54   writeVHDL vhdl "../vhdl/vhdl/output.vhdl"
55
56 -- | Show the core structure of the given binds in the given file.
57 listBind :: String -> String -> IO ()
58 listBind filename name = do
59   core <- loadModule filename
60   let binds = findBinds core [name]
61   putStr "\n"
62   putStr $ prettyShow binds
63   putStr "\n\n"
64
65 -- | Translate the binds with the given names from the given core module to
66 --   VHDL
67 moduleToVHDL :: HscTypes.CoreModule -> [String] -> IO AST.DesignFile
68 moduleToVHDL core names = do
69   --liftIO $ putStr $ prettyShow (cm_binds core)
70   let binds = findBinds core names
71   --putStr $ prettyShow binds
72   -- Turn bind into VHDL
73   let (vhdl, sess) = State.runState (mkVHDL binds) (VHDLSession core 0 Map.empty)
74   putStr $ render $ ForSyDe.Backend.Ppr.ppr vhdl
75   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
76   return vhdl
77
78   where
79     -- Turns the given bind into VHDL
80     mkVHDL binds = do
81       -- Add the builtin functions
82       mapM addBuiltIn builtin_funcs
83       -- Create entities and architectures for them
84       mapM processBind binds
85       modFuncs nameFlatFunction
86       modFuncs VHDL.createEntity
87       modFuncs VHDL.createArchitecture
88       VHDL.getDesignFile
89
90 -- | Write the given design file to the given file
91 writeVHDL :: AST.DesignFile -> String -> IO ()
92 writeVHDL = ForSyDe.Backend.VHDL.FileIO.writeDesignFile
93
94 -- | Loads the given file and turns it into a core module.
95 loadModule :: String -> IO HscTypes.CoreModule
96 loadModule filename =
97   defaultErrorHandler defaultDynFlags $ do
98     runGhc (Just libdir) $ do
99       dflags <- getSessionDynFlags
100       setSessionDynFlags dflags
101       --target <- guessTarget "adder.hs" Nothing
102       --liftIO (print (showSDoc (ppr (target))))
103       --liftIO $ printTarget target
104       --setTargets [target]
105       --load LoadAllTargets
106       --core <- GHC.compileToCoreSimplified "Adders.hs"
107       core <- GHC.compileToCoreSimplified filename
108       return core
109
110 -- | Extracts the named binds from the given module.
111 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
112 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
113
114 -- | Extract a named bind from the given list of binds
115 findBind :: [CoreBind] -> String -> Maybe CoreBind
116 findBind binds lookfor =
117   -- This ignores Recs and compares the name of the bind with lookfor,
118   -- disregarding any namespaces in OccName and extra attributes in Name and
119   -- Var.
120   find (\b -> case b of 
121     Rec l -> False
122     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
123   ) binds
124
125 -- | Processes the given bind as a top level bind.
126 processBind ::
127   CoreBind                        -- The bind to process
128   -> VHDLState ()
129
130 processBind  (Rec _) = error "Recursive binders not supported"
131 processBind bind@(NonRec var expr) = do
132   -- Create the function signature
133   let ty = CoreUtils.exprType expr
134   let hsfunc = mkHsFunction var ty
135   flattenBind hsfunc bind
136
137 -- | Flattens the given bind into the given signature and adds it to the
138 --   session. Then (recursively) finds any functions it uses and does the same
139 --   with them.
140 flattenBind ::
141   HsFunction                         -- The signature to flatten into
142   -> CoreBind                        -- The bind to flatten
143   -> VHDLState ()
144
145 flattenBind _ (Rec _) = error "Recursive binders not supported"
146
147 flattenBind hsfunc bind@(NonRec var expr) = do
148   -- Flatten the function
149   let flatfunc = flattenFunction hsfunc bind
150   addFunc hsfunc
151   setFlatFunc hsfunc flatfunc
152   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc)
153   State.mapM resolvFunc used_hsfuncs
154   return ()
155
156 -- | Find the given function, flatten it and add it to the session. Then
157 --   (recursively) do the same for any functions used.
158 resolvFunc ::
159   HsFunction        -- | The function to look for
160   -> VHDLState ()
161
162 resolvFunc hsfunc = do
163   -- See if the function is already known
164   func <- getFunc hsfunc
165   case func of
166     -- Already known, do nothing
167     Just _ -> do
168       return ()
169     -- New function, resolve it
170     Nothing -> do
171       -- Get the current module
172       core <- getModule
173       -- Find the named function
174       let bind = findBind (cm_binds core) name
175       case bind of
176         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
177         Just b  -> flattenBind hsfunc b
178   where
179     name = hsFuncName hsfunc
180
181 -- | Translate a top level function declaration to a HsFunction. i.e., which
182 --   interface will be provided by this function. This function essentially
183 --   defines the "calling convention" for hardware models.
184 mkHsFunction ::
185   Var.Var         -- ^ The function defined
186   -> Type         -- ^ The function type (including arguments!)
187   -> HsFunction   -- ^ The resulting HsFunction
188
189 mkHsFunction f ty =
190   HsFunction hsname hsargs hsres
191   where
192     hsname  = getOccString f
193     (arg_tys, res_ty) = Type.splitFunTys ty
194     -- The last argument must be state
195     state_ty = last arg_tys
196     state    = useAsState (mkHsValueMap state_ty)
197     -- All but the last argument are inports
198     inports = map (useAsPort . mkHsValueMap)(init arg_tys)
199     hsargs   = inports ++ [state]
200     hsres    = case splitTupleType res_ty of
201       -- Result type must be a two tuple (state, ports)
202       Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
203         then
204           Tuple [state, useAsPort (mkHsValueMap outport_ty)]
205         else
206           error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
207       otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
208
209 -- | Adds signal names to the given FlatFunction
210 nameFlatFunction ::
211   HsFunction
212   -> FuncData
213   -> VHDLState ()
214
215 nameFlatFunction hsfunc fdata =
216   let func = flatFunc fdata in
217   case func of
218     -- Skip (builtin) functions without a FlatFunction
219     Nothing -> do return ()
220     -- Name the signals in all other functions
221     Just flatfunc ->
222       let s = flat_sigs flatfunc in
223       let s' = map (\(id, (SignalInfo Nothing use ty)) -> (id, SignalInfo (Just $ "sig_" ++ (show id)) use ty)) s in
224       let flatfunc' = flatfunc { flat_sigs = s' } in
225       setFlatFunc hsfunc flatfunc'
226
227 -- | Splits a tuple type into a list of element types, or Nothing if the type
228 --   is not a tuple type.
229 splitTupleType ::
230   Type              -- ^ The type to split
231   -> Maybe [Type]   -- ^ The tuples element types
232
233 splitTupleType ty =
234   case Type.splitTyConApp_maybe ty of
235     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
236       then
237         Just args
238       else
239         Nothing
240     Nothing -> Nothing
241
242 -- | A consise representation of a (set of) ports on a builtin function
243 type PortMap = HsValueMap (String, AST.TypeMark)
244 -- | A consise representation of a builtin function
245 data BuiltIn = BuiltIn String [PortMap] PortMap
246
247 -- | Map a port specification of a builtin function to a VHDL Signal to put in
248 --   a VHDLSignalMap
249 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
250 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
251
252 -- | Translate a concise representation of a builtin function to something
253 --   that can be put into FuncMap directly.
254 addBuiltIn :: BuiltIn -> VHDLState ()
255 addBuiltIn (BuiltIn name args res) = do
256     addFunc hsfunc
257     setEntity hsfunc entity
258   where
259     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
260     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing
261
262 builtin_funcs = 
263   [ 
264     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
265     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
266     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
267     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
268   ]
269
270 -- vim: set ts=8 sw=2 sts=2 expandtab: