Use a different approach for marking SigUses.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Translator where
2 import GHC hiding (loadModule)
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import qualified Data.Map as Map
14 import Data.Generics
15 import NameEnv ( lookupNameEnv )
16 import qualified HscTypes
17 import HscTypes ( cm_binds, cm_types )
18 import MonadUtils ( liftIO )
19 import Outputable ( showSDoc, ppr )
20 import GHC.Paths ( libdir )
21 import DynFlags ( defaultDynFlags )
22 import List ( find )
23 import qualified List
24 import qualified Monad
25
26 -- The following modules come from the ForSyDe project. They are really
27 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
28 -- ForSyDe to get access to these modules.
29 import qualified ForSyDe.Backend.VHDL.AST as AST
30 import qualified ForSyDe.Backend.VHDL.Ppr
31 import qualified ForSyDe.Backend.VHDL.FileIO
32 import qualified ForSyDe.Backend.Ppr
33 -- This is needed for rendering the pretty printed VHDL
34 import Text.PrettyPrint.HughesPJ (render)
35
36 import TranslatorTypes
37 import HsValueMap
38 import Pretty
39 import Flatten
40 import FlattenTypes
41 import VHDLTypes
42 import qualified VHDL
43
44 main = do
45   -- Load the module
46   core <- loadModule "Adders.hs"
47   -- Translate to VHDL
48   vhdl <- moduleToVHDL core ["dff"]
49   -- Write VHDL to file
50   writeVHDL vhdl "../vhdl/vhdl/output.vhdl"
51
52 -- | Show the core structure of the given binds in the given file.
53 listBind :: String -> String -> IO ()
54 listBind filename name = do
55   core <- loadModule filename
56   let binds = findBinds core [name]
57   putStr "\n"
58   putStr $ prettyShow binds
59   putStr "\n\n"
60
61 -- | Translate the binds with the given names from the given core module to
62 --   VHDL
63 moduleToVHDL :: HscTypes.CoreModule -> [String] -> IO AST.DesignFile
64 moduleToVHDL core names = do
65   --liftIO $ putStr $ prettyShow (cm_binds core)
66   let binds = findBinds core names
67   --putStr $ prettyShow binds
68   -- Turn bind into VHDL
69   let (vhdl, sess) = State.runState (mkVHDL binds) (VHDLSession core 0 Map.empty)
70   putStr $ render $ ForSyDe.Backend.Ppr.ppr vhdl
71   putStr $ "\n\nFinal session:\n" ++ prettyShow sess ++ "\n\n"
72   return vhdl
73
74   where
75     -- Turns the given bind into VHDL
76     mkVHDL binds = do
77       -- Add the builtin functions
78       mapM addBuiltIn builtin_funcs
79       -- Create entities and architectures for them
80       mapM processBind binds
81       modFuncs nameFlatFunction
82       modFuncs VHDL.createEntity
83       modFuncs VHDL.createArchitecture
84       VHDL.getDesignFile
85
86 -- | Write the given design file to the given file
87 writeVHDL :: AST.DesignFile -> String -> IO ()
88 writeVHDL = ForSyDe.Backend.VHDL.FileIO.writeDesignFile
89
90 -- | Loads the given file and turns it into a core module.
91 loadModule :: String -> IO HscTypes.CoreModule
92 loadModule filename =
93   defaultErrorHandler defaultDynFlags $ do
94     runGhc (Just libdir) $ do
95       dflags <- getSessionDynFlags
96       setSessionDynFlags dflags
97       --target <- guessTarget "adder.hs" Nothing
98       --liftIO (print (showSDoc (ppr (target))))
99       --liftIO $ printTarget target
100       --setTargets [target]
101       --load LoadAllTargets
102       --core <- GHC.compileToCoreSimplified "Adders.hs"
103       core <- GHC.compileToCoreSimplified filename
104       return core
105
106 -- | Extracts the named binds from the given module.
107 findBinds :: HscTypes.CoreModule -> [String] -> [CoreBind]
108 findBinds core names = Maybe.mapMaybe (findBind (cm_binds core)) names
109
110 -- | Extract a named bind from the given list of binds
111 findBind :: [CoreBind] -> String -> Maybe CoreBind
112 findBind binds lookfor =
113   -- This ignores Recs and compares the name of the bind with lookfor,
114   -- disregarding any namespaces in OccName and extra attributes in Name and
115   -- Var.
116   find (\b -> case b of 
117     Rec l -> False
118     NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
119   ) binds
120
121 -- | Processes the given bind as a top level bind.
122 processBind ::
123   CoreBind                        -- The bind to process
124   -> VHDLState ()
125
126 processBind  (Rec _) = error "Recursive binders not supported"
127 processBind bind@(NonRec var expr) = do
128   -- Create the function signature
129   let ty = CoreUtils.exprType expr
130   let hsfunc = mkHsFunction var ty
131   flattenBind hsfunc bind
132
133 -- | Flattens the given bind into the given signature and adds it to the
134 --   session. Then (recursively) finds any functions it uses and does the same
135 --   with them.
136 flattenBind ::
137   HsFunction                         -- The signature to flatten into
138   -> CoreBind                        -- The bind to flatten
139   -> VHDLState ()
140
141 flattenBind _ (Rec _) = error "Recursive binders not supported"
142
143 flattenBind hsfunc bind@(NonRec var expr) = do
144   -- Flatten the function
145   let flatfunc = flattenFunction hsfunc bind
146   addFunc hsfunc
147   setFlatFunc hsfunc flatfunc
148   let used_hsfuncs = Maybe.mapMaybe usedHsFunc (flat_defs flatfunc)
149   State.mapM resolvFunc used_hsfuncs
150   return ()
151
152 -- | Find the given function, flatten it and add it to the session. Then
153 --   (recursively) do the same for any functions used.
154 resolvFunc ::
155   HsFunction        -- | The function to look for
156   -> VHDLState ()
157
158 resolvFunc hsfunc = do
159   -- See if the function is already known
160   func <- getFunc hsfunc
161   case func of
162     -- Already known, do nothing
163     Just _ -> do
164       return ()
165     -- New function, resolve it
166     Nothing -> do
167       -- Get the current module
168       core <- getModule
169       -- Find the named function
170       let bind = findBind (cm_binds core) name
171       case bind of
172         Nothing -> error $ "Couldn't find function " ++ name ++ " in current module."
173         Just b  -> flattenBind hsfunc b
174   where
175     name = hsFuncName hsfunc
176
177 -- | Translate a top level function declaration to a HsFunction. i.e., which
178 --   interface will be provided by this function. This function essentially
179 --   defines the "calling convention" for hardware models.
180 mkHsFunction ::
181   Var.Var         -- ^ The function defined
182   -> Type         -- ^ The function type (including arguments!)
183   -> HsFunction   -- ^ The resulting HsFunction
184
185 mkHsFunction f ty =
186   HsFunction hsname hsargs hsres
187   where
188     hsname  = getOccString f
189     (arg_tys, res_ty) = Type.splitFunTys ty
190     -- The last argument must be state
191     state_ty = last arg_tys
192     state    = useAsState (mkHsValueMap state_ty)
193     -- All but the last argument are inports
194     inports = map (useAsPort . mkHsValueMap)(init arg_tys)
195     hsargs   = inports ++ [state]
196     hsres    = case splitTupleType res_ty of
197       -- Result type must be a two tuple (state, ports)
198       Just [outstate_ty, outport_ty] -> if Type.coreEqType state_ty outstate_ty
199         then
200           Tuple [state, useAsPort (mkHsValueMap outport_ty)]
201         else
202           error $ "Input state type of function " ++ hsname ++ ": " ++ (showSDoc $ ppr state_ty) ++ " does not match output state type: " ++ (showSDoc $ ppr outstate_ty)
203       otherwise                -> error $ "Return type of top-level function " ++ hsname ++ " must be a two-tuple containing a state and output ports."
204
205 -- | Adds signal names to the given FlatFunction
206 nameFlatFunction ::
207   HsFunction
208   -> FuncData
209   -> VHDLState ()
210
211 nameFlatFunction hsfunc fdata =
212   let func = flatFunc fdata in
213   case func of
214     -- Skip (builtin) functions without a FlatFunction
215     Nothing -> do return ()
216     -- Name the signals in all other functions
217     Just flatfunc ->
218       let s = flat_sigs flatfunc in
219       let s' = map (\(id, (SignalInfo Nothing use ty)) -> (id, SignalInfo (Just $ "sig_" ++ (show id)) use ty)) s in
220       let flatfunc' = flatfunc { flat_sigs = s' } in
221       setFlatFunc hsfunc flatfunc'
222
223 -- | Splits a tuple type into a list of element types, or Nothing if the type
224 --   is not a tuple type.
225 splitTupleType ::
226   Type              -- ^ The type to split
227   -> Maybe [Type]   -- ^ The tuples element types
228
229 splitTupleType ty =
230   case Type.splitTyConApp_maybe ty of
231     Just (tycon, args) -> if TyCon.isTupleTyCon tycon 
232       then
233         Just args
234       else
235         Nothing
236     Nothing -> Nothing
237
238 -- | A consise representation of a (set of) ports on a builtin function
239 type PortMap = HsValueMap (String, AST.TypeMark)
240 -- | A consise representation of a builtin function
241 data BuiltIn = BuiltIn String [PortMap] PortMap
242
243 -- | Map a port specification of a builtin function to a VHDL Signal to put in
244 --   a VHDLSignalMap
245 toVHDLSignalMap :: HsValueMap (String, AST.TypeMark) -> VHDLSignalMap
246 toVHDLSignalMap = fmap (\(name, ty) -> Just (VHDL.mkVHDLId name, ty))
247
248 -- | Translate a concise representation of a builtin function to something
249 --   that can be put into FuncMap directly.
250 addBuiltIn :: BuiltIn -> VHDLState ()
251 addBuiltIn (BuiltIn name args res) = do
252     addFunc hsfunc
253     setEntity hsfunc entity
254   where
255     hsfunc = HsFunction name (map useAsPort args) (useAsPort res)
256     entity = Entity (VHDL.mkVHDLId name) (map toVHDLSignalMap args) (toVHDLSignalMap res) Nothing
257
258 builtin_funcs = 
259   [ 
260     BuiltIn "hwxor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
261     BuiltIn "hwand" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
262     BuiltIn "hwor" [(Single ("a", VHDL.bit_ty)), (Single ("b", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty)),
263     BuiltIn "hwnot" [(Single ("a", VHDL.bit_ty))] (Single ("o", VHDL.bit_ty))
264   ]
265
266 -- vim: set ts=8 sw=2 sts=2 expandtab: