Run mkHWFunction and addFunc in a State monad.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Main(main) where
2 import GHC
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import Data.Generics
14 import NameEnv ( lookupNameEnv )
15 import HscTypes ( cm_binds, cm_types )
16 import MonadUtils ( liftIO )
17 import Outputable ( showSDoc, ppr )
18 import GHC.Paths ( libdir )
19 import DynFlags ( defaultDynFlags )
20 import List ( find )
21 -- The following modules come from the ForSyDe project. They are really
22 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
23 -- ForSyDe to get access to these modules.
24 import qualified ForSyDe.Backend.VHDL.AST as AST
25 import qualified ForSyDe.Backend.VHDL.Ppr
26 import qualified ForSyDe.Backend.Ppr
27 -- This is needed for rendering the pretty printed VHDL
28 import Text.PrettyPrint.HughesPJ (render)
29
30 main = 
31                 do
32                         defaultErrorHandler defaultDynFlags $ do
33                                 runGhc (Just libdir) $ do
34                                         dflags <- getSessionDynFlags
35                                         setSessionDynFlags dflags
36                                         --target <- guessTarget "adder.hs" Nothing
37                                         --liftIO (print (showSDoc (ppr (target))))
38                                         --liftIO $ printTarget target
39                                         --setTargets [target]
40                                         --load LoadAllTargets
41                                         --core <- GHC.compileToCoreSimplified "Adders.hs"
42                                         core <- GHC.compileToCoreSimplified "Adders.hs"
43                                         liftIO $ printBinds (cm_binds core)
44                                         let bind = findBind "half_adder" (cm_binds core)
45                                         let NonRec var expr = bind
46                                         let sess = State.execState (do {(name, f) <- mkHWFunction bind; addFunc name f}) (VHDLSession 0 builtin_funcs)
47                                         liftIO $ putStr $ showSDoc $ ppr expr
48                                         liftIO $ putStr "\n\n"
49                                         liftIO $ putStr $ render $ ForSyDe.Backend.Ppr.ppr $ getArchitecture sess bind
50                                         return expr
51
52 printTarget (Target (TargetFile file (Just x)) obj Nothing) =
53         print $ show file
54
55 printBinds [] = putStr "done\n\n"
56 printBinds (b:bs) = do
57         printBind b
58         putStr "\n"
59         printBinds bs
60
61 printBind (NonRec b expr) = do
62         putStr "NonRec: "
63         printBind' (b, expr)
64
65 printBind (Rec binds) = do
66         putStr "Rec: \n"        
67         foldl1 (>>) (map printBind' binds)
68
69 printBind' (b, expr) = do
70         putStr $ getOccString b
71         --putStr $ showSDoc $ ppr expr
72         putStr "\n"
73
74 findBind :: String -> [CoreBind] -> CoreBind
75 findBind lookfor =
76         -- This ignores Recs and compares the name of the bind with lookfor,
77         -- disregarding any namespaces in OccName and extra attributes in Name and
78         -- Var.
79         Maybe.fromJust . find (\b -> case b of 
80                 Rec l -> False
81                 NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
82         )
83
84 -- Accepts a port name and an argument to map to it.
85 -- Returns the appropriate line for in the port map
86 getPortMapEntry binds (Port portname) (Var id) = 
87         (Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))
88         where
89                 Port signalname = Maybe.fromMaybe
90                         (error $ "Argument " ++ getOccString id ++ "is unknown")
91                         (lookup id binds)
92
93 getPortMapEntry binds _ a = error $ "Unsupported argument: " ++ (showSDoc $ ppr a)
94
95 getInstantiations ::
96         VHDLSession
97         -> [PortNameMap]             -- The arguments that need to be applied to the
98                                                                                                                          -- expression.
99         -> PortNameMap               -- The output ports that the expression should generate.
100         -> [(CoreBndr, PortNameMap)] -- A list of bindings in effect
101         -> CoreSyn.CoreExpr          -- The expression to generate an architecture for
102         -> [AST.ConcSm]              -- The resulting VHDL code
103
104 -- A lambda expression binds the first argument (a) to the binder b.
105 getInstantiations sess (a:as) outs binds (Lam b expr) =
106         getInstantiations sess as outs ((b, a):binds) expr
107
108 -- A case expression that checks a single variable and has a single
109 -- alternative, can be used to take tuples apart
110 getInstantiations sess args outs binds (Case (Var v) b _ [res]) =
111         case altcon of
112                 DataAlt datacon ->
113                         if (DataCon.isTupleCon datacon) then
114                                 getInstantiations sess args outs binds' expr
115                         else
116                                 error "Data constructors other than tuples not supported"
117                 otherwise ->
118                         error "Case binders other than tuples not supported"
119         where
120                 binds' = (zip bind_vars tuple_ports) ++ binds
121                 (altcon, bind_vars, expr) = res
122                 -- Find the portnamemaps for each of the tuple's elements
123                 Tuple tuple_ports = Maybe.fromMaybe 
124                         (error $ "Case expression uses unknown scrutinee " ++ getOccString v)
125                         (lookup v binds)
126
127 -- An application is an instantiation of a component
128 getInstantiations sess args outs binds app@(App expr arg) =
129         if isTupleConstructor f then
130                 let
131                         Tuple outports = outs
132                         (tys, vals) = splitTupleConstructorArgs fargs
133                 in
134                         concat $ zipWith 
135                                 (\outs' expr' -> getInstantiations sess args outs' binds expr')
136                                 outports vals
137         else
138                 [AST.CSISm comp]
139         where
140                 ((Var f), fargs) = collectArgs app
141                 comp = AST.CompInsSm
142                         (AST.unsafeVHDLBasicId "app")
143                         (AST.IUEntity (AST.NSimple (AST.unsafeVHDLBasicId compname)))
144                         (AST.PMapAspect ports)
145                 compname = getOccString f
146                 hwfunc = Maybe.fromMaybe
147                         (error $ "Function " ++ compname ++ "is unknown")
148                         (lookup compname (funcs sess))
149                 HWFunction inports outport = hwfunc
150                 ports = 
151                         zipWith (getPortMapEntry binds) inports fargs
152                   ++ mapOutputPorts outport outs
153
154 getInstantiations sess args outs binds expr = 
155         error $ "Unsupported expression" ++ (showSDoc $ ppr $ expr)
156
157 -- Is the given name a (binary) tuple constructor
158 isTupleConstructor :: Var.Var -> Bool
159 isTupleConstructor var =
160         Name.isWiredInName name
161         && Name.nameModule name == tuple_mod
162         && (Name.occNameString $ Name.nameOccName name) == "(,)"
163         where
164                 name = Var.varName var
165                 mod = nameModule name
166                 tuple_mod = Module.mkModule (Module.stringToPackageId "ghc-prim") (Module.mkModuleName "GHC.Tuple")
167
168 -- Split arguments into type arguments and value arguments This is probably
169 -- not really sufficient (not sure if Types can actually occur as value
170 -- arguments...)
171 splitTupleConstructorArgs :: [CoreExpr] -> ([CoreExpr], [CoreExpr])
172 splitTupleConstructorArgs (e:es) =
173         case e of
174                 Type t     -> (e:tys, vals)
175                 otherwise  -> (tys, e:vals)
176         where
177                 (tys, vals) = splitTupleConstructorArgs es
178
179 mapOutputPorts ::
180         PortNameMap         -- The output portnames of the component
181         -> PortNameMap      -- The output portnames and/or signals to map these to
182         -> [AST.AssocElem]  -- The resulting output ports
183
184 -- Map the output port of a component to the output port of the containing
185 -- entity.
186 mapOutputPorts (Port portname) (Port signalname) =
187         [(Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))]
188
189 -- Map matching output ports in the tuple
190 mapOutputPorts (Tuple ports) (Tuple signals) =
191         concat (zipWith mapOutputPorts ports signals)
192
193 getArchitecture ::
194         VHDLSession
195         -> CoreBind               -- The binder to expand into an architecture
196         -> AST.ArchBody           -- The resulting architecture
197          
198 getArchitecture sess (Rec _) = error "Recursive binders not supported"
199
200 getArchitecture sess (NonRec var expr) =
201         AST.ArchBody
202                 (AST.unsafeVHDLBasicId "structural")
203                 -- Use unsafe for now, to prevent pulling in ForSyDe error handling
204                 (AST.NSimple (AST.unsafeVHDLBasicId name))
205                 []
206                 (getInstantiations sess inports outport [] expr)
207         where
208                 name = (getOccString var)
209                 hwfunc = Maybe.fromMaybe
210                         (error $ "Function " ++ name ++ "is unknown? This should not happen!")
211                         (lookup name (funcs sess))
212                 HWFunction inports outport = hwfunc
213
214 data PortNameMap =
215         Tuple [PortNameMap]
216         | Port  String
217   deriving (Show)
218
219 -- Generate a port name map (or multiple for tuple types) in the given direction for
220 -- each type given.
221 getPortNameMapForTys :: String -> Int -> [Type] -> [PortNameMap]
222 getPortNameMapForTys prefix num [] = [] 
223 getPortNameMapForTys prefix num (t:ts) =
224         (getPortNameMapForTy (prefix ++ show num) t) : getPortNameMapForTys prefix (num + 1) ts
225
226 getPortNameMapForTy     :: String -> Type -> PortNameMap
227 getPortNameMapForTy name ty =
228         if (TyCon.isTupleTyCon tycon) then
229                 -- Expand tuples we find
230                 Tuple (getPortNameMapForTys name 0 args)
231         else -- Assume it's a type constructor application, ie simple data type
232                 -- TODO: Add type?
233                 Port name
234         where
235                 (tycon, args) = Type.splitTyConApp ty 
236
237 data HWFunction = HWFunction { -- A function that is available in hardware
238         inPorts   :: [PortNameMap],
239         outPort   :: PortNameMap
240         --entity    :: AST.EntityDec
241 } deriving (Show)
242
243 -- Turns a CoreExpr describing a function into a description of its input and
244 -- output ports.
245 mkHWFunction ::
246         CoreBind                                   -- The core binder to generate the interface for
247         -> VHDLState (String, HWFunction)          -- The name of the function and its interface
248
249 mkHWFunction (NonRec var expr) =
250                 return (name, HWFunction inports outport)
251         where
252                 name = (getOccString var)
253                 ty = CoreUtils.exprType expr
254                 (fargs, res) = Type.splitFunTys ty
255                 args = if length fargs == 1 then fargs else (init fargs)
256                 --state = if length fargs == 1 then () else (last fargs)
257                 inports = case args of
258                         -- Handle a single port specially, to prevent an extra 0 in the name
259                         [port] -> [getPortNameMapForTy "portin" port]
260                         ps     -> getPortNameMapForTys "portin" 0 ps
261                 outport = getPortNameMapForTy "portout" res
262
263 mkHWFunction (Rec _) =
264         error "Recursive binders not supported"
265
266 data VHDLSession = VHDLSession {
267         nameCount :: Int,                      -- A counter that can be used to generate unique names
268         funcs     :: [(String, HWFunction)]    -- All functions available, indexed by name
269 } deriving (Show)
270
271 type VHDLState = State.State VHDLSession
272
273 -- Add the function to the session
274 addFunc :: String -> HWFunction -> VHDLState ()
275 addFunc name f = do
276         fs <- State.gets funcs -- Get the funcs element form the session
277         State.modify (\x -> x {funcs = (name, f) : fs }) -- Prepend name and f
278
279 builtin_funcs = 
280         [ 
281                 ("hwxor", HWFunction [Port "a", Port "b"] (Port "o")),
282                 ("hwand", HWFunction [Port "a", Port "b"] (Port "o"))
283         ]