Put getInstantiations in the State monad.
[matthijs/master-project/cλash.git] / Translator.hs
1 module Main(main) where
2 import GHC
3 import CoreSyn
4 import qualified CoreUtils
5 import qualified Var
6 import qualified Type
7 import qualified TyCon
8 import qualified DataCon
9 import qualified Maybe
10 import qualified Module
11 import qualified Control.Monad.State as State
12 import Name
13 import Data.Generics
14 import NameEnv ( lookupNameEnv )
15 import HscTypes ( cm_binds, cm_types )
16 import MonadUtils ( liftIO )
17 import Outputable ( showSDoc, ppr )
18 import GHC.Paths ( libdir )
19 import DynFlags ( defaultDynFlags )
20 import List ( find )
21 -- The following modules come from the ForSyDe project. They are really
22 -- internal modules, so ForSyDe.cabal has to be modified prior to installing
23 -- ForSyDe to get access to these modules.
24 import qualified ForSyDe.Backend.VHDL.AST as AST
25 import qualified ForSyDe.Backend.VHDL.Ppr
26 import qualified ForSyDe.Backend.Ppr
27 -- This is needed for rendering the pretty printed VHDL
28 import Text.PrettyPrint.HughesPJ (render)
29
30 main = 
31                 do
32                         defaultErrorHandler defaultDynFlags $ do
33                                 runGhc (Just libdir) $ do
34                                         dflags <- getSessionDynFlags
35                                         setSessionDynFlags dflags
36                                         --target <- guessTarget "adder.hs" Nothing
37                                         --liftIO (print (showSDoc (ppr (target))))
38                                         --liftIO $ printTarget target
39                                         --setTargets [target]
40                                         --load LoadAllTargets
41                                         --core <- GHC.compileToCoreSimplified "Adders.hs"
42                                         core <- GHC.compileToCoreSimplified "Adders.hs"
43                                         liftIO $ printBinds (cm_binds core)
44                                         let bind = findBind "half_adder" (cm_binds core)
45                                         let NonRec var expr = bind
46                                         -- Turn bind into VHDL
47                                         let vhdl = State.evalState (mkVHDL bind) (VHDLSession 0 builtin_funcs)
48                                         liftIO $ putStr $ showSDoc $ ppr expr
49                                         liftIO $ putStr "\n\n"
50                                         liftIO $ putStr $ render $ ForSyDe.Backend.Ppr.ppr $ vhdl
51                                         return expr
52         where
53                 -- Turns the given bind into VHDL
54                 mkVHDL bind = do
55                         -- Get the function signature
56                         (name, f) <- mkHWFunction bind
57                         -- Add it to the session
58                         addFunc name f
59                         arch <- getArchitecture bind
60                         return arch
61
62 printTarget (Target (TargetFile file (Just x)) obj Nothing) =
63         print $ show file
64
65 printBinds [] = putStr "done\n\n"
66 printBinds (b:bs) = do
67         printBind b
68         putStr "\n"
69         printBinds bs
70
71 printBind (NonRec b expr) = do
72         putStr "NonRec: "
73         printBind' (b, expr)
74
75 printBind (Rec binds) = do
76         putStr "Rec: \n"        
77         foldl1 (>>) (map printBind' binds)
78
79 printBind' (b, expr) = do
80         putStr $ getOccString b
81         --putStr $ showSDoc $ ppr expr
82         putStr "\n"
83
84 findBind :: String -> [CoreBind] -> CoreBind
85 findBind lookfor =
86         -- This ignores Recs and compares the name of the bind with lookfor,
87         -- disregarding any namespaces in OccName and extra attributes in Name and
88         -- Var.
89         Maybe.fromJust . find (\b -> case b of 
90                 Rec l -> False
91                 NonRec var _ -> lookfor == (occNameString $ nameOccName $ getName var)
92         )
93
94 -- Accepts a port name and an argument to map to it.
95 -- Returns the appropriate line for in the port map
96 getPortMapEntry binds (Port portname) (Var id) = 
97         (Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))
98         where
99                 Port signalname = Maybe.fromMaybe
100                         (error $ "Argument " ++ getOccString id ++ "is unknown")
101                         (lookup id binds)
102
103 getPortMapEntry binds _ a = error $ "Unsupported argument: " ++ (showSDoc $ ppr a)
104
105 getInstantiations ::
106         [PortNameMap]                -- The arguments that need to be applied to the
107                                                                                                                          -- expression.
108         -> PortNameMap               -- The output ports that the expression should generate.
109         -> [(CoreBndr, PortNameMap)] -- A list of bindings in effect
110         -> CoreSyn.CoreExpr          -- The expression to generate an architecture for
111         -> VHDLState [AST.ConcSm]    -- The resulting VHDL code
112
113 -- A lambda expression binds the first argument (a) to the binder b.
114 getInstantiations (a:as) outs binds (Lam b expr) =
115         getInstantiations as outs ((b, a):binds) expr
116
117 -- A case expression that checks a single variable and has a single
118 -- alternative, can be used to take tuples apart
119 getInstantiations args outs binds (Case (Var v) b _ [res]) =
120         case altcon of
121                 DataAlt datacon ->
122                         if (DataCon.isTupleCon datacon) then
123                                 getInstantiations args outs binds' expr
124                         else
125                                 error "Data constructors other than tuples not supported"
126                 otherwise ->
127                         error "Case binders other than tuples not supported"
128         where
129                 binds' = (zip bind_vars tuple_ports) ++ binds
130                 (altcon, bind_vars, expr) = res
131                 -- Find the portnamemaps for each of the tuple's elements
132                 Tuple tuple_ports = Maybe.fromMaybe 
133                         (error $ "Case expression uses unknown scrutinee " ++ getOccString v)
134                         (lookup v binds)
135
136 -- An application is an instantiation of a component
137 getInstantiations args outs binds app@(App expr arg) = do
138         let ((Var f), fargs) = collectArgs app
139             name = getOccString f
140         if isTupleConstructor f 
141                 then do
142                         let Tuple outports = outs
143                             (tys, vals) = splitTupleConstructorArgs fargs
144                         insts <- sequence $ zipWith 
145                                 (\outs' expr' -> getInstantiations args outs' binds expr')
146                                 outports vals
147                         return $ concat insts
148                 else do
149                         HWFunction inports outport <- getHWFunc name
150                         let comp = AST.CompInsSm
151                                                 (AST.unsafeVHDLBasicId "app")
152                                                 (AST.IUEntity (AST.NSimple (AST.unsafeVHDLBasicId name)))
153                                                 (AST.PMapAspect ports)
154                             ports = 
155                                     zipWith (getPortMapEntry binds) inports fargs
156                                     ++ mapOutputPorts outport outs
157                         return [AST.CSISm comp]
158
159 getInstantiations args outs binds expr = 
160         error $ "Unsupported expression" ++ (showSDoc $ ppr $ expr)
161
162 -- Is the given name a (binary) tuple constructor
163 isTupleConstructor :: Var.Var -> Bool
164 isTupleConstructor var =
165         Name.isWiredInName name
166         && Name.nameModule name == tuple_mod
167         && (Name.occNameString $ Name.nameOccName name) == "(,)"
168         where
169                 name = Var.varName var
170                 mod = nameModule name
171                 tuple_mod = Module.mkModule (Module.stringToPackageId "ghc-prim") (Module.mkModuleName "GHC.Tuple")
172
173 -- Split arguments into type arguments and value arguments This is probably
174 -- not really sufficient (not sure if Types can actually occur as value
175 -- arguments...)
176 splitTupleConstructorArgs :: [CoreExpr] -> ([CoreExpr], [CoreExpr])
177 splitTupleConstructorArgs (e:es) =
178         case e of
179                 Type t     -> (e:tys, vals)
180                 otherwise  -> (tys, e:vals)
181         where
182                 (tys, vals) = splitTupleConstructorArgs es
183
184 mapOutputPorts ::
185         PortNameMap         -- The output portnames of the component
186         -> PortNameMap      -- The output portnames and/or signals to map these to
187         -> [AST.AssocElem]  -- The resulting output ports
188
189 -- Map the output port of a component to the output port of the containing
190 -- entity.
191 mapOutputPorts (Port portname) (Port signalname) =
192         [(Just (AST.unsafeVHDLBasicId portname)) AST.:=>: (AST.ADName (AST.NSimple (AST.unsafeVHDLBasicId signalname)))]
193
194 -- Map matching output ports in the tuple
195 mapOutputPorts (Tuple ports) (Tuple signals) =
196         concat (zipWith mapOutputPorts ports signals)
197
198 getArchitecture ::
199         CoreBind                  -- The binder to expand into an architecture
200         -> VHDLState AST.ArchBody -- The resulting architecture
201          
202 getArchitecture (Rec _) = error "Recursive binders not supported"
203
204 getArchitecture (NonRec var expr) = do
205         let name = (getOccString var)
206         HWFunction inports outport <- getHWFunc name
207         sess <- State.get
208         insts <- getInstantiations inports outport [] expr
209         return $ AST.ArchBody
210                 (AST.unsafeVHDLBasicId "structural")
211                 -- Use unsafe for now, to prevent pulling in ForSyDe error handling
212                 (AST.NSimple (AST.unsafeVHDLBasicId name))
213                 []
214                 (insts)
215
216 data PortNameMap =
217         Tuple [PortNameMap]
218         | Port  String
219   deriving (Show)
220
221 -- Generate a port name map (or multiple for tuple types) in the given direction for
222 -- each type given.
223 getPortNameMapForTys :: String -> Int -> [Type] -> [PortNameMap]
224 getPortNameMapForTys prefix num [] = [] 
225 getPortNameMapForTys prefix num (t:ts) =
226         (getPortNameMapForTy (prefix ++ show num) t) : getPortNameMapForTys prefix (num + 1) ts
227
228 getPortNameMapForTy     :: String -> Type -> PortNameMap
229 getPortNameMapForTy name ty =
230         if (TyCon.isTupleTyCon tycon) then
231                 -- Expand tuples we find
232                 Tuple (getPortNameMapForTys name 0 args)
233         else -- Assume it's a type constructor application, ie simple data type
234                 -- TODO: Add type?
235                 Port name
236         where
237                 (tycon, args) = Type.splitTyConApp ty 
238
239 data HWFunction = HWFunction { -- A function that is available in hardware
240         inPorts   :: [PortNameMap],
241         outPort   :: PortNameMap
242         --entity    :: AST.EntityDec
243 } deriving (Show)
244
245 -- Turns a CoreExpr describing a function into a description of its input and
246 -- output ports.
247 mkHWFunction ::
248         CoreBind                                   -- The core binder to generate the interface for
249         -> VHDLState (String, HWFunction)          -- The name of the function and its interface
250
251 mkHWFunction (NonRec var expr) =
252                 return (name, HWFunction inports outport)
253         where
254                 name = (getOccString var)
255                 ty = CoreUtils.exprType expr
256                 (fargs, res) = Type.splitFunTys ty
257                 args = if length fargs == 1 then fargs else (init fargs)
258                 --state = if length fargs == 1 then () else (last fargs)
259                 inports = case args of
260                         -- Handle a single port specially, to prevent an extra 0 in the name
261                         [port] -> [getPortNameMapForTy "portin" port]
262                         ps     -> getPortNameMapForTys "portin" 0 ps
263                 outport = getPortNameMapForTy "portout" res
264
265 mkHWFunction (Rec _) =
266         error "Recursive binders not supported"
267
268 data VHDLSession = VHDLSession {
269         nameCount :: Int,                      -- A counter that can be used to generate unique names
270         funcs     :: [(String, HWFunction)]    -- All functions available, indexed by name
271 } deriving (Show)
272
273 type VHDLState = State.State VHDLSession
274
275 -- Add the function to the session
276 addFunc :: String -> HWFunction -> VHDLState ()
277 addFunc name f = do
278         fs <- State.gets funcs -- Get the funcs element from the session
279         State.modify (\x -> x {funcs = (name, f) : fs }) -- Prepend name and f
280
281 -- Lookup the function with the given name in the current session. Errors if
282 -- it was not found.
283 getHWFunc :: String -> VHDLState HWFunction
284 getHWFunc name = do
285         fs <- State.gets funcs -- Get the funcs element from the session
286         return $ Maybe.fromMaybe
287                 (error $ "Function " ++ name ++ "is unknown? This should not happen!")
288                 (lookup name fs)
289
290 builtin_funcs = 
291         [ 
292                 ("hwxor", HWFunction [Port "a", Port "b"] (Port "o")),
293                 ("hwand", HWFunction [Port "a", Port "b"] (Port "o"))
294         ]