Add support for multiple-constructor datatypes with fields.
[matthijs/master-project/cλash.git] / clash / CLasH / VHDL / Generate.hs
index 3d31529a86cc3c7b46b0930a4dc0aa748283c2cf..c9febe3d8c74fa70d51fcd9ac76f9137ca9045a0 100644 (file)
@@ -251,7 +251,7 @@ mkConcSm (bndr, expr@(CoreSyn.Case (CoreSyn.Var scrut) b ty [alt]))
         then do
           bndrs' <- Monad.filterM hasNonEmptyType bndrs
           case List.elemIndex sel_bndr bndrs' of
-            Just i -> do
+            Just sel_i -> do
               htypeScrt <- MonadState.lift tsType $ mkHTypeEither (Var.varType scrut)
               htypeBndr <- MonadState.lift tsType $ mkHTypeEither (Var.varType bndr)
               case htypeScrt == htypeBndr of
@@ -261,9 +261,10 @@ mkConcSm (bndr, expr@(CoreSyn.Case (CoreSyn.Var scrut) b ty [alt]))
                   return ([mkUncondAssign (Left bndr) sel_expr], [])
                 otherwise ->
                   case htypeScrt of
-                    Right (AggrType _ _) -> do
-                      labels <- MonadState.lift tsType $ getFieldLabels (Id.idType scrut)
-                      let label = labels!!i
+                    Right htype@(AggrType _ _ _) -> do
+                      let dc_i = datacon_index (Id.idType scrut) dc
+                      let labels = getFieldLabels htype dc_i
+                      let label = labels!!sel_i
                       let sel_name = mkSelectedName (varToVHDLName scrut) label
                       let sel_expr = AST.PrimName sel_name
                       return ([mkUncondAssign (Left bndr) sel_expr], [])
@@ -282,13 +283,34 @@ mkConcSm (bndr, expr@(CoreSyn.Case (CoreSyn.Var scrut) b ty [alt]))
 -- binders in the alts and only variables in the case values and a variable
 -- for a scrutinee. We check the constructor of the second alt, since the
 -- first is the default case, if there is any.
-mkConcSm (bndr, (CoreSyn.Case (CoreSyn.Var scrut) _ _ (alt:alts))) = do
-  scrut' <- MonadState.lift tsType $ varToVHDLExpr scrut
-  -- Omit first condition, which is the default
-  altcons <- MonadState.lift tsType $ mapM (altconToVHDLExpr . (\(con,_,_) -> con)) alts
-  let cond_exprs = map (\x -> scrut' AST.:=: x) altcons
+mkConcSm (bndr, expr@(CoreSyn.Case (CoreSyn.Var scrut) _ _ alts)) = do
+  htype <- MonadState.lift tsType $ mkHType ("\nVHDL.mkConcSm: Unrepresentable scrutinee type? Expression: " ++ pprString expr) scrut
+  -- Turn the scrutinee into a VHDLExpr
+  scrut_expr <- MonadState.lift tsType $ varToVHDLExpr scrut
+  (enums, cmp) <- case htype of
+    EnumType _ enums -> do
+      -- Enumeration type, compare with the scrutinee directly
+      return (map stringToVHDLExpr enums, scrut_expr)
+    AggrType _ (Just (name, EnumType _ enums)) _ -> do
+      -- Extract the enumeration field from the aggregation
+      let sel_name = mkSelectedName (varToVHDLName scrut) (mkVHDLBasicId name)
+      let sel_expr = AST.PrimName sel_name
+      return (map stringToVHDLExpr enums, sel_expr)
+    (BuiltinType "Bit") -> do
+      let enums = [AST.PrimLit "'1'", AST.PrimLit "'0'"]
+      return (enums, scrut_expr)
+    (BuiltinType "Bool") -> do
+      let enums = [AST.PrimLit "true", AST.PrimLit "false"]
+      return (enums, scrut_expr)
+    _ -> error $ "\nSelector case on weird scrutinee: " ++ pprString scrut ++ " scrutinee type: " ++ pprString (Id.idType scrut)
+  -- Omit first condition, which is the default. Look up each altcon in
+  -- the enums list from the HType to find the actual enum value names.
+  let altcons = map (\(CoreSyn.DataAlt dc, _, _) -> enums!!(datacon_index scrut dc)) (tail alts)
+  -- Compare the (constructor field of the) scrutinee with each of the
+  -- alternatives.
+  let cond_exprs = map (\x -> cmp AST.:=: x) altcons
   -- Rotate expressions to the left, so that the expression related to the default case is the last
-  exprs <- MonadState.lift tsType $ mapM (varToVHDLExpr . (\(_,_,CoreSyn.Var expr) -> expr)) (alts ++ [alt])
+  exprs <- MonadState.lift tsType $ mapM (varToVHDLExpr . (\(_,_,CoreSyn.Var expr) -> expr)) ((tail alts) ++ [head alts])
   return ([mkAltsAssign (Left bndr) cond_exprs exprs], [])
 
 mkConcSm (_, CoreSyn.Case _ _ _ _) = error "\nVHDL.mkConcSm: Not in normal form: Case statement does not have a simple variable as scrutinee"
@@ -725,6 +747,7 @@ genZip' :: (Either CoreSyn.CoreBndr AST.VHDLName) -> CoreSyn.CoreBndr -> [Var.Va
 genZip' (Left res) f args@[arg1, arg2] = do {
     -- Setup the generate scheme
   ; len <- MonadState.lift tsType $ tfp_to_int $ (tfvec_len_ty . Var.varType) res
+  ; res_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genZip: Invalid result type" (tfvec_elem (Var.varType res))
           -- TODO: Use something better than varToString
   ; let { label           = mkVHDLExtId ("zipVector" ++ (varToString res))
         ; n_id            = mkVHDLBasicId "n"
@@ -734,8 +757,8 @@ genZip' (Left res) f args@[arg1, arg2] = do {
         ; resname'        = mkIndexedName (varToVHDLName res) n_expr
         ; argexpr1        = vhdlNameToVHDLExpr $ mkIndexedName (varToVHDLName arg1) n_expr
         ; argexpr2        = vhdlNameToVHDLExpr $ mkIndexedName (varToVHDLName arg2) n_expr
-        } ; 
-  ; labels <- MonadState.lift tsType $ getFieldLabels (tfvec_elem (Var.varType res))
+        ; labels          = getFieldLabels res_htype 0
+        }
   ; let { resnameA    = mkSelectedName resname' (labels!!0)
         ; resnameB    = mkSelectedName resname' (labels!!1)
         ; resA_assign = mkUncondAssign (Right resnameA) argexpr1
@@ -750,8 +773,10 @@ genFst :: BuiltinBuilder
 genFst = genNoInsts $ genVarArgs genFst'
 genFst' :: (Either CoreSyn.CoreBndr AST.VHDLName) -> CoreSyn.CoreBndr -> [Var.Var] -> TranslatorSession [AST.ConcSm]
 genFst' (Left res) f args@[arg] = do {
-  ; labels <- MonadState.lift tsType $ getFieldLabels (Var.varType arg)
-  ; let { argexpr'    = varToVHDLName arg
+  ; arg_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genFst: Invalid argument type" (Var.varType arg)
+  ; let { 
+        ; labels      = getFieldLabels arg_htype 0
+        ; argexpr'    = varToVHDLName arg
         ; argexprA    = vhdlNameToVHDLExpr $ mkSelectedName argexpr' (labels!!0)
         ; assign      = mkUncondAssign (Left res) argexprA
         } ;
@@ -764,8 +789,10 @@ genSnd :: BuiltinBuilder
 genSnd = genNoInsts $ genVarArgs genSnd'
 genSnd' :: (Either CoreSyn.CoreBndr AST.VHDLName) -> CoreSyn.CoreBndr -> [Var.Var] -> TranslatorSession [AST.ConcSm]
 genSnd' (Left res) f args@[arg] = do {
-  ; labels <- MonadState.lift tsType $ getFieldLabels (Var.varType arg)
-  ; let { argexpr'    = varToVHDLName arg
+  ; arg_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genSnd: Invalid argument type" (Var.varType arg)
+  ; let { 
+        ; labels      = getFieldLabels arg_htype 0
+        ; argexpr'    = varToVHDLName arg
         ; argexprB    = vhdlNameToVHDLExpr $ mkSelectedName argexpr' (labels!!1)
         ; assign      = mkUncondAssign (Left res) argexprB
         } ;
@@ -785,9 +812,11 @@ genUnzip' (Left res) f args@[arg] = do
   -- resulting VHDL, making the the unzip no longer required.
   case htype of
     -- A normal vector containing two-tuples
-    VecType _ (AggrType _ [_, _]) -> do {
+    VecType _ (AggrType _ [_, _]) -> do {
         -- Setup the generate scheme
       ; len <- MonadState.lift tsType $ tfp_to_int $ (tfvec_len_ty . Var.varType) arg
+      ; arg_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genUnzip: Invalid argument type" (Var.varType arg)
+      ; res_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genUnzip: Invalid result type" (Var.varType res)
         -- TODO: Use something better than varToString
       ; let { label           = mkVHDLExtId ("unzipVector" ++ (varToString res))
             ; n_id            = mkVHDLBasicId "n"
@@ -796,9 +825,9 @@ genUnzip' (Left res) f args@[arg] = do
             ; genScheme       = AST.ForGn n_id range
             ; resname'        = varToVHDLName res
             ; argexpr'        = mkIndexedName (varToVHDLName arg) n_expr
+            ; reslabels       = getFieldLabels res_htype 0
+            ; arglabels       = getFieldLabels arg_htype 0
             } ;
-      ; reslabels <- MonadState.lift tsType $ getFieldLabels (Var.varType res)
-      ; arglabels <- MonadState.lift tsType $ getFieldLabels (tfvec_elem (Var.varType arg))
       ; let { resnameA    = mkIndexedName (mkSelectedName resname' (reslabels!!0)) n_expr
             ; resnameB    = mkIndexedName (mkSelectedName resname' (reslabels!!1)) n_expr
             ; argexprA    = vhdlNameToVHDLExpr $ mkSelectedName argexpr' (arglabels!!0)
@@ -811,9 +840,9 @@ genUnzip' (Left res) f args@[arg] = do
       }
     -- Both elements of the tuple were state, so they've disappeared. No
     -- need to do anything
-    VecType _ (AggrType _ []) -> return []
+    VecType _ (AggrType _ []) -> return []
     -- A vector containing aggregates with more than two elements?
-    VecType _ (AggrType _ _) -> error $ "Unzipping a value that is not a vector of two-tuples? Value: " ++ pprString arg ++ "\nType: " ++ pprString (Var.varType arg)
+    VecType _ (AggrType _ _ _) -> error $ "Unzipping a value that is not a vector of two-tuples? Value: " ++ pprString arg ++ "\nType: " ++ pprString (Var.varType arg)
     -- One of the elements of the tuple was state, so there won't be a
     -- tuple (record) in the VHDL output. We can just do a plain
     -- assignment, then.
@@ -997,9 +1026,11 @@ genSplit = genNoInsts $ genVarArgs genSplit'
 
 genSplit' :: (Either CoreSyn.CoreBndr AST.VHDLName) -> CoreSyn.CoreBndr -> [Var.Var] -> TranslatorSession [AST.ConcSm]
 genSplit' (Left res) f args@[vecIn] = do {
-  ; labels <- MonadState.lift tsType $ getFieldLabels (Var.varType res)
   ; len <- MonadState.lift tsType $ tfp_to_int $ (tfvec_len_ty . Var.varType) vecIn
-  ; let { block_label = mkVHDLExtId ("split" ++ (varToString vecIn))
+  ; res_htype <- MonadState.lift tsType $ mkHType "\nGenerate.genSplit': Invalid result type" (Var.varType res)
+  ; let { 
+        ; labels    = getFieldLabels res_htype 0
+        ; block_label = mkVHDLExtId ("split" ++ (varToString vecIn))
         ; halflen   = round ((fromIntegral len) / 2)
         ; rangeL    = vecSlice (AST.PrimLit "0") (AST.PrimLit $ show (halflen - 1))
         ; rangeR    = vecSlice (AST.PrimLit $ show halflen) (AST.PrimLit $ show (len - 1))
@@ -1039,16 +1070,17 @@ genApplication dst f args = do
             -- It's a datacon. Create a record from its arguments.
             Left bndr -> do
               -- We have the bndr, so we can get at the type
-              htype <- MonadState.lift tsType $ mkHTypeEither (Var.varType bndr)
+              htype_either <- MonadState.lift tsType $ mkHTypeEither (Var.varType bndr)
               let argsNostate = filter (\x -> not (either hasStateType (\x -> False) x)) args
               case argsNostate of
                 [arg] -> do
                   [arg'] <- argsToVHDLExprs [arg]
                   return ([mkUncondAssign dst arg'], [])
                 otherwise ->
-                  case htype of
-                    Right (AggrType _ _) -> do
-                      labels <- MonadState.lift tsType $ getFieldLabels (Var.varType bndr)
+                  case htype_either of
+                    Right htype@(AggrType _ _ _) -> do
+                      let dc_i = datacon_index (Var.varType bndr) dc
+                      let labels = getFieldLabels htype dc_i
                       args' <- argsToVHDLExprs argsNostate
                       return (zipWith mkassign labels args', [])
                       where