Make the state monad calling code more pretty.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Main(main) where
2 import GHC
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import Data.Generics
14 import NameEnv ( lookupNameEnv )
15 import HscTypes ( cm_binds, cm_types )
16 import MonadUtils ( liftIO )
17 import Outputable ( showSDoc, ppr )
18 import GHC.Paths ( libdir )
19 import DynFlags ( defaultDynFlags )
20 import List ( find )
21 -- The following modules come from the ForSyDe project. They are really
22 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
23 -- ForSyDe to get access to these modules.
24 import qualified ForSyDe.Backend.VHDL.AST as AST
25 import qualified ForSyDe.Backend.VHDL.Ppr
26 import qualified ForSyDe.Backend.Ppr
27 -- This is needed for rendering the pretty printed VHDL
28 import Text.PrettyPrint.HughesPJ (render)
29
30 main = 
31                 do
32                         defaultErrorHandler defaultDynFlags $ do
33                                 runGhc (Just libdir) $ do
34                                         dflags <- getSessionDynFlags
35                                         setSessionDynFlags dflags
36                                         --target <- guessTarget "adder.hs" Nothing
37                                         --liftIO (print (showSDoc (ppr (target))))
38                                         --liftIO $ printTarget target
39                                         --setTargets [target]
40                                         --load LoadAllTargets
41                                         --core <- GHC.compileToCoreSimplified "Adders.hs"
42                                         core <- GHC.compileToCoreSimplified "Adders.hs"
43                                         liftIO $ printBinds (cm_binds core)
44                                         let bind = findBind "half_adder" (cm_binds core)
45                                         let NonRec var expr = bind
46                                         -- Add the HWFunction from the bind to the session
47                                         let sess = State.execState (addF bind) (VHDLSession 0 builtin_funcs)
48                                         liftIO $ putStr $ showSDoc $ ppr expr
49                                         liftIO $ putStr "\n\n"
50                                         liftIO $ putStr $ render $ ForSyDe.Backend.Ppr.ppr $ getArchitecture sess bind
51                                         return expr
52         where
53                 -- Turns the given bind into VHDL
54                 addF bind = do
55                         -- Get the function signature
56                         (name, f) <- mkHWFunction bind
57                         -- Add it to the session
58                         addFunc name f
59
60 printTarget (Target (TargetFile file (Just x)) obj Nothing) =
61         print $ show file
62
63 printBinds [] = putStr "done\n\n"
64 printBinds (b:bs) = do
65         printBind b
66         putStr "\n"
67         printBinds bs
68
69 printBind (NonRec b expr) = do
70         putStr "NonRec: "
71         printBind' (b, expr)
72
73 printBind (Rec binds) = do
74         putStr "Rec: \n"        
75         foldl1 (>>) (map printBind' binds)
76
77 printBind' (b, expr) = do
78         putStr $ getOccString b
79         --putStr $ showSDoc $ ppr expr
80         putStr "\n"
81
82 findBind :: String -> [CoreBind] -> CoreBind
83 findBind lookfor =
84         -- This ignores Recs and compares the name of the bind with lookfor,
85         -- disregarding any namespaces in OccName and extra attributes in Name and
86         -- Var.
87         Maybe.fromJust . find (\b -> case b of 
88                 Rec l -> False
89                 NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
90         )
91
92 -- Accepts a port name and an argument to map to it.
93 -- Returns the appropriate line for in the port map
94 getPortMapEntry binds (Port portname) (Var id) = 
95         (Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))
96         where
97                 Port signalname = Maybe.fromMaybe
98                         (error $ "Argument " ++ getOccString id ++ "is unknown")
99                         (lookup id binds)
100
101 getPortMapEntry binds _ a = error $ "Unsupported argument: " ++ (showSDoc $ ppr a)
102
103 getInstantiations ::
104         VHDLSession
105         -> [PortNameMap]             -- The arguments that need to be applied to the
106                                                                                                                          -- expression.
107         -> PortNameMap               -- The output ports that the expression should generate.
108         -> [(CoreBndr, PortNameMap)] -- A list of bindings in effect
109         -> CoreSyn.CoreExpr          -- The expression to generate an architecture for
110         -> [AST.ConcSm]              -- The resulting VHDL code
111
112 -- A lambda expression binds the first argument (a) to the binder b.
113 getInstantiations sess (a:as) outs binds (Lam b expr) =
114         getInstantiations sess as outs ((b, a):binds) expr
115
116 -- A case expression that checks a single variable and has a single
117 -- alternative, can be used to take tuples apart
118 getInstantiations sess args outs binds (Case (Var v) b _ [res]) =
119         case altcon of
120                 DataAlt datacon ->
121                         if (DataCon.isTupleCon datacon) then
122                                 getInstantiations sess args outs binds' expr
123                         else
124                                 error "Data constructors other than tuples not supported"
125                 otherwise ->
126                         error "Case binders other than tuples not supported"
127         where
128                 binds' = (zip bind_vars tuple_ports) ++ binds
129                 (altcon, bind_vars, expr) = res
130                 -- Find the portnamemaps for each of the tuple's elements
131                 Tuple tuple_ports = Maybe.fromMaybe 
132                         (error $ "Case expression uses unknown scrutinee " ++ getOccString v)
133                         (lookup v binds)
134
135 -- An application is an instantiation of a component
136 getInstantiations sess args outs binds app@(App expr arg) =
137         if isTupleConstructor f then
138                 let
139                         Tuple outports = outs
140                         (tys, vals) = splitTupleConstructorArgs fargs
141                 in
142                         concat $ zipWith 
143                                 (\outs' expr' -> getInstantiations sess args outs' binds expr')
144                                 outports vals
145         else
146                 [AST.CSISm comp]
147         where
148                 ((Var f), fargs) = collectArgs app
149                 comp = AST.CompInsSm
150                         (AST.unsafeVHDLBasicId "app")
151                         (AST.IUEntity (AST.NSimple (AST.unsafeVHDLBasicId compname)))
152                         (AST.PMapAspect ports)
153                 compname = getOccString f
154                 hwfunc = Maybe.fromMaybe
155                         (error $ "Function " ++ compname ++ "is unknown")
156                         (lookup compname (funcs sess))
157                 HWFunction inports outport = hwfunc
158                 ports = 
159                         zipWith (getPortMapEntry binds) inports fargs
160                   ++ mapOutputPorts outport outs
161
162 getInstantiations sess args outs binds expr = 
163         error $ "Unsupported expression" ++ (showSDoc $ ppr $ expr)
164
165 -- Is the given name a (binary) tuple constructor
166 isTupleConstructor :: Var.Var -> Bool
167 isTupleConstructor var =
168         Name.isWiredInName name
169         && Name.nameModule name == tuple_mod
170         && (Name.occNameString $ Name.nameOccName name) == "(,)"
171         where
172                 name = Var.varName var
173                 mod = nameModule name
174                 tuple_mod = Module.mkModule (Module.stringToPackageId "ghc-prim") (Module.mkModuleName "GHC.Tuple")
175
176 -- Split arguments into type arguments and value arguments This is probably
177 -- not really sufficient (not sure if Types can actually occur as value
178 -- arguments...)
179 splitTupleConstructorArgs :: [CoreExpr] -> ([CoreExpr], [CoreExpr])
180 splitTupleConstructorArgs (e:es) =
181         case e of
182                 Type t     -> (e:tys, vals)
183                 otherwise  -> (tys, e:vals)
184         where
185                 (tys, vals) = splitTupleConstructorArgs es
186
187 mapOutputPorts ::
188         PortNameMap         -- The output portnames of the component
189         -> PortNameMap      -- The output portnames and/or signals to map these to
190         -> [AST.AssocElem]  -- The resulting output ports
191
192 -- Map the output port of a component to the output port of the containing
193 -- entity.
194 mapOutputPorts (Port portname) (Port signalname) =
195         [(Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))]
196
197 -- Map matching output ports in the tuple
198 mapOutputPorts (Tuple ports) (Tuple signals) =
199         concat (zipWith mapOutputPorts ports signals)
200
201 getArchitecture ::
202         VHDLSession
203         -> CoreBind               -- The binder to expand into an architecture
204         -> AST.ArchBody           -- The resulting architecture
205          
206 getArchitecture sess (Rec _) = error "Recursive binders not supported"
207
208 getArchitecture sess (NonRec var expr) =
209         AST.ArchBody
210                 (AST.unsafeVHDLBasicId "structural")
211                 -- Use unsafe for now, to prevent pulling in ForSyDe error handling
212                 (AST.NSimple (AST.unsafeVHDLBasicId name))
213                 []
214                 (getInstantiations sess inports outport [] expr)
215         where
216                 name = (getOccString var)
217                 hwfunc = Maybe.fromMaybe
218                         (error $ "Function " ++ name ++ "is unknown? This should not happen!")
219                         (lookup name (funcs sess))
220                 HWFunction inports outport = hwfunc
221
222 data PortNameMap =
223         Tuple [PortNameMap]
224         | Port  String
225   deriving (Show)
226
227 -- Generate a port name map (or multiple for tuple types) in the given direction for
228 -- each type given.
229 getPortNameMapForTys :: String -> Int -> [Type] -> [PortNameMap]
230 getPortNameMapForTys prefix num [] = [] 
231 getPortNameMapForTys prefix num (t:ts) =
232         (getPortNameMapForTy (prefix ++ show num) t) : getPortNameMapForTys prefix (num + 1) ts
233
234 getPortNameMapForTy     :: String -> Type -> PortNameMap
235 getPortNameMapForTy name ty =
236         if (TyCon.isTupleTyCon tycon) then
237                 -- Expand tuples we find
238                 Tuple (getPortNameMapForTys name 0 args)
239         else -- Assume it's a type constructor application, ie simple data type
240                 -- TODO: Add type?
241                 Port name
242         where
243                 (tycon, args) = Type.splitTyConApp ty 
244
245 data HWFunction = HWFunction { -- A function that is available in hardware
246         inPorts   :: [PortNameMap],
247         outPort   :: PortNameMap
248         --entity    :: AST.EntityDec
249 } deriving (Show)
250
251 -- Turns a CoreExpr describing a function into a description of its input and
252 -- output ports.
253 mkHWFunction ::
254         CoreBind                                   -- The core binder to generate the interface for
255         -> VHDLState (String, HWFunction)          -- The name of the function and its interface
256
257 mkHWFunction (NonRec var expr) =
258                 return (name, HWFunction inports outport)
259         where
260                 name = (getOccString var)
261                 ty = CoreUtils.exprType expr
262                 (fargs, res) = Type.splitFunTys ty
263                 args = if length fargs == 1 then fargs else (init fargs)
264                 --state = if length fargs == 1 then () else (last fargs)
265                 inports = case args of
266                         -- Handle a single port specially, to prevent an extra 0 in the name
267                         [port] -> [getPortNameMapForTy "portin" port]
268                         ps     -> getPortNameMapForTys "portin" 0 ps
269                 outport = getPortNameMapForTy "portout" res
270
271 mkHWFunction (Rec _) =
272         error "Recursive binders not supported"
273
274 data VHDLSession = VHDLSession {
275         nameCount :: Int,                      -- A counter that can be used to generate unique names
276         funcs     :: [(String, HWFunction)]    -- All functions available, indexed by name
277 } deriving (Show)
278
279 type VHDLState = State.State VHDLSession
280
281 -- Add the function to the session
282 addFunc :: String -> HWFunction -> VHDLState ()
283 addFunc name f = do
284         fs <- State.gets funcs -- Get the funcs element form the session
285         State.modify (\x -> x {funcs = (name, f) : fs }) -- Prepend name and f
286
287 builtin_funcs = 
288         [ 
289                 ("hwxor", HWFunction [Port "a", Port "b"] (Port "o")),
290                 ("hwand", HWFunction [Port "a", Port "b"] (Port "o"))
291         ]