Put getArchitecture inside the State monad.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Main(main) where
2 import GHC
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import Data.Generics
14 import NameEnv ( lookupNameEnv )
15 import HscTypes ( cm_binds, cm_types )
16 import MonadUtils ( liftIO )
17 import Outputable ( showSDoc, ppr )
18 import GHC.Paths ( libdir )
19 import DynFlags ( defaultDynFlags )
20 import List ( find )
21 -- The following modules come from the ForSyDe project. They are really
22 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
23 -- ForSyDe to get access to these modules.
24 import qualified ForSyDe.Backend.VHDL.AST as AST
25 import qualified ForSyDe.Backend.VHDL.Ppr
26 import qualified ForSyDe.Backend.Ppr
27 -- This is needed for rendering the pretty printed VHDL
28 import Text.PrettyPrint.HughesPJ (render)
29
30 main = 
31                 do
32                         defaultErrorHandler defaultDynFlags $ do
33                                 runGhc (Just libdir) $ do
34                                         dflags <- getSessionDynFlags
35                                         setSessionDynFlags dflags
36                                         --target <- guessTarget "adder.hs" Nothing
37                                         --liftIO (print (showSDoc (ppr (target))))
38                                         --liftIO $ printTarget target
39                                         --setTargets [target]
40                                         --load LoadAllTargets
41                                         --core <- GHC.compileToCoreSimplified "Adders.hs"
42                                         core <- GHC.compileToCoreSimplified "Adders.hs"
43                                         liftIO $ printBinds (cm_binds core)
44                                         let bind = findBind "half_adder" (cm_binds core)
45                                         let NonRec var expr = bind
46                                         -- Turn bind into VHDL
47                                         let vhdl = State.evalState (mkVHDL bind) (VHDLSession 0 builtin_funcs)
48                                         liftIO $ putStr $ showSDoc $ ppr expr
49                                         liftIO $ putStr "\n\n"
50                                         liftIO $ putStr $ render $ ForSyDe.Backend.Ppr.ppr $ vhdl
51                                         return expr
52         where
53                 -- Turns the given bind into VHDL
54                 mkVHDL bind = do
55                         -- Get the function signature
56                         (name, f) <- mkHWFunction bind
57                         -- Add it to the session
58                         addFunc name f
59                         arch <- getArchitecture bind
60                         return arch
61
62 printTarget (Target (TargetFile file (Just x)) obj Nothing) =
63         print $ show file
64
65 printBinds [] = putStr "done\n\n"
66 printBinds (b:bs) = do
67         printBind b
68         putStr "\n"
69         printBinds bs
70
71 printBind (NonRec b expr) = do
72         putStr "NonRec: "
73         printBind' (b, expr)
74
75 printBind (Rec binds) = do
76         putStr "Rec: \n"        
77         foldl1 (>>) (map printBind' binds)
78
79 printBind' (b, expr) = do
80         putStr $ getOccString b
81         --putStr $ showSDoc $ ppr expr
82         putStr "\n"
83
84 findBind :: String -> [CoreBind] -> CoreBind
85 findBind lookfor =
86         -- This ignores Recs and compares the name of the bind with lookfor,
87         -- disregarding any namespaces in OccName and extra attributes in Name and
88         -- Var.
89         Maybe.fromJust . find (\b -> case b of 
90                 Rec l -> False
91                 NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
92         )
93
94 -- Accepts a port name and an argument to map to it.
95 -- Returns the appropriate line for in the port map
96 getPortMapEntry binds (Port portname) (Var id) = 
97         (Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))
98         where
99                 Port signalname = Maybe.fromMaybe
100                         (error $ "Argument " ++ getOccString id ++ "is unknown")
101                         (lookup id binds)
102
103 getPortMapEntry binds _ a = error $ "Unsupported argument: " ++ (showSDoc $ ppr a)
104
105 getInstantiations ::
106         VHDLSession
107         -> [PortNameMap]             -- The arguments that need to be applied to the
108                                                                                                                          -- expression.
109         -> PortNameMap               -- The output ports that the expression should generate.
110         -> [(CoreBndr, PortNameMap)] -- A list of bindings in effect
111         -> CoreSyn.CoreExpr          -- The expression to generate an architecture for
112         -> [AST.ConcSm]              -- The resulting VHDL code
113
114 -- A lambda expression binds the first argument (a) to the binder b.
115 getInstantiations sess (a:as) outs binds (Lam b expr) =
116         getInstantiations sess as outs ((b, a):binds) expr
117
118 -- A case expression that checks a single variable and has a single
119 -- alternative, can be used to take tuples apart
120 getInstantiations sess args outs binds (Case (Var v) b _ [res]) =
121         case altcon of
122                 DataAlt datacon ->
123                         if (DataCon.isTupleCon datacon) then
124                                 getInstantiations sess args outs binds' expr
125                         else
126                                 error "Data constructors other than tuples not supported"
127                 otherwise ->
128                         error "Case binders other than tuples not supported"
129         where
130                 binds' = (zip bind_vars tuple_ports) ++ binds
131                 (altcon, bind_vars, expr) = res
132                 -- Find the portnamemaps for each of the tuple's elements
133                 Tuple tuple_ports = Maybe.fromMaybe 
134                         (error $ "Case expression uses unknown scrutinee " ++ getOccString v)
135                         (lookup v binds)
136
137 -- An application is an instantiation of a component
138 getInstantiations sess args outs binds app@(App expr arg) =
139         if isTupleConstructor f then
140                 let
141                         Tuple outports = outs
142                         (tys, vals) = splitTupleConstructorArgs fargs
143                 in
144                         concat $ zipWith 
145                                 (\outs' expr' -> getInstantiations sess args outs' binds expr')
146                                 outports vals
147         else
148                 [AST.CSISm comp]
149         where
150                 ((Var f), fargs) = collectArgs app
151                 comp = AST.CompInsSm
152                         (AST.unsafeVHDLBasicId "app")
153                         (AST.IUEntity (AST.NSimple (AST.unsafeVHDLBasicId compname)))
154                         (AST.PMapAspect ports)
155                 compname = getOccString f
156                 hwfunc = Maybe.fromMaybe
157                         (error $ "Function " ++ compname ++ "is unknown")
158                         (lookup compname (funcs sess))
159                 HWFunction inports outport = hwfunc
160                 ports = 
161                         zipWith (getPortMapEntry binds) inports fargs
162                   ++ mapOutputPorts outport outs
163
164 getInstantiations sess args outs binds expr = 
165         error $ "Unsupported expression" ++ (showSDoc $ ppr $ expr)
166
167 -- Is the given name a (binary) tuple constructor
168 isTupleConstructor :: Var.Var -> Bool
169 isTupleConstructor var =
170         Name.isWiredInName name
171         && Name.nameModule name == tuple_mod
172         && (Name.occNameString $ Name.nameOccName name) == "(,)"
173         where
174                 name = Var.varName var
175                 mod = nameModule name
176                 tuple_mod = Module.mkModule (Module.stringToPackageId "ghc-prim") (Module.mkModuleName "GHC.Tuple")
177
178 -- Split arguments into type arguments and value arguments This is probably
179 -- not really sufficient (not sure if Types can actually occur as value
180 -- arguments...)
181 splitTupleConstructorArgs :: [CoreExpr] -> ([CoreExpr], [CoreExpr])
182 splitTupleConstructorArgs (e:es) =
183         case e of
184                 Type t     -> (e:tys, vals)
185                 otherwise  -> (tys, e:vals)
186         where
187                 (tys, vals) = splitTupleConstructorArgs es
188
189 mapOutputPorts ::
190         PortNameMap         -- The output portnames of the component
191         -> PortNameMap      -- The output portnames and/or signals to map these to
192         -> [AST.AssocElem]  -- The resulting output ports
193
194 -- Map the output port of a component to the output port of the containing
195 -- entity.
196 mapOutputPorts (Port portname) (Port signalname) =
197         [(Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))]
198
199 -- Map matching output ports in the tuple
200 mapOutputPorts (Tuple ports) (Tuple signals) =
201         concat (zipWith mapOutputPorts ports signals)
202
203 getArchitecture ::
204         CoreBind                  -- The binder to expand into an architecture
205         -> VHDLState AST.ArchBody -- The resulting architecture
206          
207 getArchitecture (Rec _) = error "Recursive binders not supported"
208
209 getArchitecture (NonRec var expr) = do
210         HWFunction inports outport <- getHWFunc name
211         sess <- State.get
212         return $ AST.ArchBody
213                 (AST.unsafeVHDLBasicId "structural")
214                 -- Use unsafe for now, to prevent pulling in ForSyDe error handling
215                 (AST.NSimple (AST.unsafeVHDLBasicId name))
216                 []
217                 (getInstantiations sess inports outport [] expr)
218         where
219                 name = (getOccString var)
220
221 data PortNameMap =
222         Tuple [PortNameMap]
223         | Port  String
224   deriving (Show)
225
226 -- Generate a port name map (or multiple for tuple types) in the given direction for
227 -- each type given.
228 getPortNameMapForTys :: String -> Int -> [Type] -> [PortNameMap]
229 getPortNameMapForTys prefix num [] = [] 
230 getPortNameMapForTys prefix num (t:ts) =
231         (getPortNameMapForTy (prefix ++ show num) t) : getPortNameMapForTys prefix (num + 1) ts
232
233 getPortNameMapForTy     :: String -> Type -> PortNameMap
234 getPortNameMapForTy name ty =
235         if (TyCon.isTupleTyCon tycon) then
236                 -- Expand tuples we find
237                 Tuple (getPortNameMapForTys name 0 args)
238         else -- Assume it's a type constructor application, ie simple data type
239                 -- TODO: Add type?
240                 Port name
241         where
242                 (tycon, args) = Type.splitTyConApp ty 
243
244 data HWFunction = HWFunction { -- A function that is available in hardware
245         inPorts   :: [PortNameMap],
246         outPort   :: PortNameMap
247         --entity    :: AST.EntityDec
248 } deriving (Show)
249
250 -- Turns a CoreExpr describing a function into a description of its input and
251 -- output ports.
252 mkHWFunction ::
253         CoreBind                                   -- The core binder to generate the interface for
254         -> VHDLState (String, HWFunction)          -- The name of the function and its interface
255
256 mkHWFunction (NonRec var expr) =
257                 return (name, HWFunction inports outport)
258         where
259                 name = (getOccString var)
260                 ty = CoreUtils.exprType expr
261                 (fargs, res) = Type.splitFunTys ty
262                 args = if length fargs == 1 then fargs else (init fargs)
263                 --state = if length fargs == 1 then () else (last fargs)
264                 inports = case args of
265                         -- Handle a single port specially, to prevent an extra 0 in the name
266                         [port] -> [getPortNameMapForTy "portin" port]
267                         ps     -> getPortNameMapForTys "portin" 0 ps
268                 outport = getPortNameMapForTy "portout" res
269
270 mkHWFunction (Rec _) =
271         error "Recursive binders not supported"
272
273 data VHDLSession = VHDLSession {
274         nameCount :: Int,                      -- A counter that can be used to generate unique names
275         funcs     :: [(String, HWFunction)]    -- All functions available, indexed by name
276 } deriving (Show)
277
278 type VHDLState = State.State VHDLSession
279
280 -- Add the function to the session
281 addFunc :: String -> HWFunction -> VHDLState ()
282 addFunc name f = do
283         fs <- State.gets funcs -- Get the funcs element from the session
284         State.modify (\x -> x {funcs = (name, f) : fs }) -- Prepend name and f
285
286 -- Lookup the function with the given name in the current session. Errors if
287 -- it was not found.
288 getHWFunc :: String -> VHDLState HWFunction
289 getHWFunc name = do
290         fs <- State.gets funcs -- Get the funcs element from the session
291         return $ Maybe.fromMaybe
292                 (error $ "Function " ++ name ++ "is unknown? This should not happen!")
293                 (lookup name fs)
294
295 builtin_funcs = 
296         [ 
297                 ("hwxor", HWFunction [Port "a", Port "b"] (Port "o")),
298                 ("hwand", HWFunction [Port "a", Port "b"] (Port "o"))
299         ]