Add support for multiple-constructor datatypes with fields.
[matthijs/master-project/cλash.git] / clash / CLasH / VHDL / VHDLTools.hs
index 7677369a59c784c516464e9030eb07742fdfa257..639452bcbef3f6a00aa6f1189e88e70e343d4b59 100644 (file)
@@ -356,42 +356,62 @@ mkTyConHType tycon args =
   case TyCon.tyConDataCons tycon of
     -- Not an algebraic type
     [] -> return $ Left $ "VHDLTools.mkTyConHType: Only custom algebraic types are supported: " ++ pprString tycon
-    [dc] -> do
-      let arg_tys = DataCon.dataConRepArgTys dc
-      let real_arg_tys = map (CoreSubst.substTy subst) arg_tys
-      let real_arg_tys_nostate = filter (\x -> not (isStateType x)) real_arg_tys
-      elem_htys_either <- mapM mkHTypeEither real_arg_tys_nostate
-      case Either.partitionEithers elem_htys_either of
-        ([], [elem_hty]) ->
-          return $ Right elem_hty
-        -- No errors in element types
-        ([], elem_htys) ->
-          return $ Right $ AggrType (nameToString (TyCon.tyConName tycon)) elem_htys
-        -- There were errors in element types
-        (errors, _) -> return $ Left $
-          "\nVHDLTools.mkTyConHType: Can not construct type for: " ++ pprString tycon ++ "\n because no type can be construced for some of the arguments.\n"
-          ++ (concat errors)
     dcs -> do
-      let arg_tys = concatMap DataCon.dataConRepArgTys dcs
-      let real_arg_tys = map (CoreSubst.substTy subst) arg_tys
-      case real_arg_tys of
-        [] ->
-          return $ Right $ EnumType (nameToString (TyCon.tyConName tycon)) (map (nameToString . DataCon.dataConName) dcs)
-        xs -> return $ Left $
-          "VHDLTools.mkTyConHType: Only enum-like constructor datatypes supported: " ++ pprString dcs ++ "\n"
+      let arg_tyss = map DataCon.dataConRepArgTys dcs
+      let enum_ty = EnumType name (map (nameToString . DataCon.dataConName) dcs)
+      case (concat arg_tyss) of
+        -- No arguments, this is just an enumeration type
+        [] -> return (Right enum_ty)
+        -- At least one argument, this becomes an aggregate type
+        _ -> do
+          -- Resolve any type arguments to this type
+          let real_arg_tyss = map (map (CoreSubst.substTy subst)) arg_tyss
+          -- Remove any state type fields
+          let real_arg_tyss_nostate = map (filter (\x -> not (isStateType x))) real_arg_tyss
+          elem_htyss_either <- mapM (mapM mkHTypeEither) real_arg_tyss_nostate
+          let (errors, elem_htyss) = unzip (map Either.partitionEithers elem_htyss_either)
+          case errors of
+            [] -> case (dcs, concat elem_htyss) of
+                -- A single constructor with a single (non-state) field?
+                ([dc], [elem_hty]) -> return $ Right elem_hty
+                -- If we get here, then all of the argument types were state
+                -- types (we check for enumeration types at the top). Not
+                -- sure how to handle this, so error out for now.
+                (_, []) -> error $ "ADT with only State elements (or something like that?) Dunno how to handle this yet. Tycon: " ++ pprString tycon ++ " Arguments: " ++ pprString args
+                -- A full ADT (with multiple fields and one or multiple
+                -- constructors).
+                (_, elem_htys) -> do
+                  let (_, fieldss) = List.mapAccumL (List.mapAccumL label_field) labels elem_htyss
+                  -- Only put in an enumeration as part of the aggregation
+                  -- when there are multiple datacons
+                  let enum_ty_part = case dcs of
+                                      [dc] -> Nothing
+                                      _ -> Just ("constructor", enum_ty)
+                  -- Create the AggrType HType
+                  return $ Right $ AggrType name enum_ty_part fieldss
+                -- There were errors in element types
+            errors -> return $ Left $
+              "\nVHDLTools.mkTyConHType: Can not construct type for: " ++ pprString tycon ++ "\n because no type can be construced for some of the arguments.\n"
+              ++ (concat $ concat errors)
   where
+    name = (nameToString (TyCon.tyConName tycon))
     tyvars = TyCon.tyConTyVars tycon
     subst = CoreSubst.extendTvSubstList CoreSubst.emptySubst (zip tyvars args)
+    -- Label a field by taking the first available label and returning
+    -- the rest.
+    label_field :: [String] -> HType -> ([String], (String, HType))
+    label_field (l:ls) htype = (ls, (l, htype))
+    labels = map (:[]) ['A'..'Z']
 
--- Translate a Haskell type to a VHDL type, generating a new type if needed.
--- Returns an error value, using the given message, when no type could be
--- created. Returns Nothing when the type is valid, but empty.
 vhdlTy :: (TypedThing t, Outputable.Outputable t) => 
   String -> t -> TypeSession (Maybe AST.TypeMark)
 vhdlTy msg ty = do
   htype <- mkHType msg ty
   vhdlTyMaybe htype
 
+-- | Translate a Haskell type to a VHDL type, generating a new type if needed.
+-- Returns an error value, using the given message, when no type could be
+-- created. Returns Nothing when the type is valid, but empty.
 vhdlTyMaybe :: HType -> TypeSession (Maybe AST.TypeMark)
 vhdlTyMaybe htype = do
   typemap <- MonadState.get tsTypes
@@ -429,17 +449,45 @@ construct_vhdl_ty htype =
 mkTyconTy :: HType -> TypeSession TypeMapRec
 mkTyconTy htype =
   case htype of
-    (AggrType tycon args) -> do
-      elemTysMaybe <- mapM vhdlTyMaybe args
-      case Maybe.catMaybes elemTysMaybe of
-        [] -> -- No non-empty members
+    (AggrType name enum_field_maybe fieldss) -> do
+      let (labelss, elem_htypess) = unzip (map unzip fieldss)
+      elemTyMaybess <- mapM (mapM vhdlTyMaybe) elem_htypess
+      let elem_tyss = map Maybe.catMaybes elemTyMaybess
+      case concat elem_tyss of
+        [] -> -- No non-empty fields
           return Nothing
-        elem_tys -> do
-          let elems = zipWith AST.ElementDec recordlabels elem_tys  
-          let elem_names = concatMap prettyShow elem_tys
-          let ty_id = mkVHDLExtId $ tycon ++ elem_names
-          let ty_def = AST.TDR $ AST.RecordTypeDef elems
-          let tupshow = mkTupleShow elem_tys ty_id
+        _ -> do
+          let reclabelss = map (map mkVHDLBasicId) labelss
+          let elemss = zipWith (zipWith AST.ElementDec) reclabelss elem_tyss
+          let elem_names = concatMap (concatMap prettyShow) elem_tyss
+          let ty_id = mkVHDLExtId $ name ++ elem_names
+          -- Find out if we need to add an extra field at the start of
+          -- the record type containing the constructor (only needed
+          -- when there's more than one constructor).
+          enum_ty_maybe <- case enum_field_maybe of
+            Nothing -> return Nothing
+            Just (_, enum_htype) -> do
+              enum_ty_maybe' <- vhdlTyMaybe enum_htype
+              case enum_ty_maybe' of
+                Nothing -> error $ "Couldn't translate enumeration type part of AggrType: " ++ show htype
+                -- Note that the first Just means the type is
+                -- translateable, while the second Just means that there
+                -- is a enum_ty at all (e.g., there's multiple
+                -- constructors).
+                Just enum_ty -> return $ Just enum_ty
+          -- Create an record field declaration for the first
+          -- constructor field, if needed.
+          enum_dec_maybe <- case enum_field_maybe of
+            Nothing -> return $ Nothing
+            Just (enum_name, enum_htype) -> do
+              enum_vhdl_ty_maybe <- vhdlTyMaybe  enum_htype
+              let enum_vhdl_ty = Maybe.fromMaybe (error $ "\nVHDLTools.mkTyconTy: Enumeration field should not have empty type: " ++ show enum_htype) enum_vhdl_ty_maybe
+              return $ Just $ AST.ElementDec (mkVHDLBasicId enum_name) enum_vhdl_ty
+          -- Turn the maybe into a list, so we can prepend it.
+          let enum_decs = Maybe.maybeToList enum_dec_maybe
+          let enum_tys = Maybe.maybeToList enum_ty_maybe
+          let ty_def = AST.TDR $ AST.RecordTypeDef (enum_decs ++ concat elemss)
+          let tupshow = mkTupleShow (enum_tys ++ concat elem_tyss) ty_id
           MonadState.modify tsTypeFuns $ Map.insert (htype, showIdString) (showId, tupshow)
           return $ Just (ty_id, Just $ Left ty_def)
     (EnumType tycon dcs) -> do
@@ -450,9 +498,6 @@ mkTyconTy htype =
       MonadState.modify tsTypeFuns $ Map.insert (htype, showIdString) (showId, enumShow)
       return $ Just (ty_id, Just $ Left ty_def)
     otherwise -> error $ "\nVHDLTools.mkTyconTy: Called for HType that is neiter a AggrType or EnumType: " ++ show htype
-  where
-    -- Generate a bunch of labels for fields of a record
-    recordlabels = map (\c -> mkVHDLBasicId [c]) ['A'..'Z']
 
 -- | Create a VHDL vector type
 mkVectorTy ::
@@ -515,23 +560,27 @@ mkSignedTy size = do
   let ty_def = AST.SubtypeIn signedTM (Just range)
   return (Just (ty_id, Just $ Right ty_def))
 
--- Finds the field labels for VHDL type generated for the given Core type,
--- which must result in a record type.
-getFieldLabels :: Type.Type -> TypeSession [AST.VHDLId]
-getFieldLabels ty = do
-  -- Ensure that the type is generated (but throw away it's VHDLId)
-  let error_msg = "\nVHDLTools.getFieldLabels: Can not get field labels, because: " ++ pprString ty ++ "can not be generated." 
-  vhdlTy error_msg ty
-  -- Get the types map, lookup and unpack the VHDL TypeDef
-  types <- MonadState.get tsTypes
-  -- Assume the type for which we want labels is really translatable
-  htype <- mkHType error_msg ty
-  case Map.lookup htype types of
-    Nothing -> error $ "\nVHDLTools.getFieldLabels: Type not found? This should not happen!\nLooking for type: " ++ (pprString ty) ++ "\nhtype: " ++ (show htype) 
-    Just Nothing -> return [] -- The type is empty
-    Just (Just (_, Just (Left (AST.TDR (AST.RecordTypeDef elems))))) -> return $ map (\(AST.ElementDec id _) -> id) elems
-    Just (Just (_, Just vty)) -> error $ "\nVHDLTools.getFieldLabels: Type not a record type? This should not happen!\nLooking for type: " ++ pprString (ty) ++ "\nhtype: " ++ (show htype) ++ "\nFound type: " ++ (show vty)
-    
+-- Finds the field labels and types for aggregation HType. Returns an
+-- error on other types.
+getFields ::
+  HType                -- ^ The HType to get fields for
+  -> Int               -- ^ The constructor to get fields for (e.g., 0
+                       --   for the first constructor, etc.)
+  -> [(String, HType)] -- ^ A list of fields, with their name and type
+getFields htype dc_i = case htype of
+  (AggrType name _ fieldss) 
+    | dc_i >= 0 && dc_i < length fieldss -> fieldss!!dc_i
+    | otherwise -> error $ "Invalid constructor index: " ++ (show dc_i) ++ ". No such constructor in HType: " ++ (show htype)
+  _ -> error $ "Can't get fields from non-aggregate HType: " ++ show htype
+
+-- Finds the field labels for an aggregation type, as VHDLIds.
+getFieldLabels ::
+  HType                -- ^ The HType to get field labels for
+  -> Int               -- ^ The constructor to get fields for (e.g., 0
+                       --   for the first constructor, etc.)
+  -> [AST.VHDLId]      -- ^ The labels
+getFieldLabels htype dc_i = ((map mkVHDLBasicId) . (map fst)) (getFields htype dc_i)
+
 mktydecl :: (AST.VHDLId, Maybe (Either AST.TypeDef AST.SubtypeIn)) -> Maybe AST.PackageDecItem
 mytydecl (_, Nothing) = Nothing
 mktydecl (ty_id, Just (Left ty_def)) = Just $ AST.PDITD $ AST.TypeDec ty_id ty_def