Make casesimpl support multiple-alt cases with fields.
[matthijs/master-project/cλash.git] / clash / CLasH / Normalize.hs
index ea171ca05f7b783a6731845e037a4c7216ee9291..72885b7b9cd4f3c9ae00f9853ac0f102ea91b57b 100644 (file)
@@ -380,7 +380,7 @@ funextract c expr@(App _ _) | is_var fexpr = do
     -- We could use is_applicable here instead of is_fun, but I think
     -- arguments to functions could only have forall typing when existential
     -- typing is enabled. Not sure, though.
-    doarg arg | not (is_simple arg) && is_fun arg = do
+    doarg arg | not (is_simple arg) && is_fun arg && not (has_free_tyvars arg) = do
       -- Create a new top level binding that binds the argument. Its body will
       -- be extended with lambda expressions, to take any free variables used
       -- by the argument expression.
@@ -410,15 +410,14 @@ funextract c expr = return expr
 -- Make sure the scrutinee of a case expression is a local variable
 -- reference.
 scrutsimpl :: Transform
--- Don't touch scrutinees that are already simple
-scrutsimpl c expr@(Case (Var _) _ _ _) = return expr
--- Replace all other cases with a let that binds the scrutinee and a new
+-- Replace a case expression with a let that binds the scrutinee and a new
 -- simple scrutinee, but only when the scrutinee is representable (to prevent
 -- loops with inlinenonrep, though I don't think a non-representable scrutinee
--- will be supported anyway...) 
+-- will be supported anyway...) and is not a local variable already.
 scrutsimpl c expr@(Case scrut b ty alts) = do
   repr <- isRepr scrut
-  if repr
+  local_var <- Trans.lift $ is_local_var scrut
+  if repr && not local_var
     then do
       id <- Trans.lift $ mkBinderFor scrut "scrut"
       change $ Let (NonRec id scrut) (Case (Var id) b ty alts)
@@ -490,7 +489,9 @@ casesimpl c expr@(Case scrut bndr ty alts) | not bndr_used = do
   -- Wilden the binders of one alt, producing a list of bindings as a
   -- sideeffect.
   doalt :: CoreAlt -> TransformMonad ([(CoreBndr, CoreExpr)], CoreAlt)
-  doalt (con, bndrs, expr) = do
+  doalt (LitAlt _, _, _) = error $ "Don't know how to handle LitAlt in case expression: " ++ pprString expr
+  doalt alt@(DEFAULT, [], expr) = return ([], alt)
+  doalt (DataAlt dc, bndrs, expr) = do
     -- Make each binder wild, if possible
     bndrs_res <- Monad.zipWithM dobndr bndrs [0..]
     let (newbndrs, bindings_maybe) = unzip bndrs_res
@@ -500,7 +501,7 @@ casesimpl c expr@(Case scrut bndr ty alts) | not bndr_used = do
     let uses_bndrs = not $ VarSet.isEmptyVarSet $ CoreFVs.exprSomeFreeVars (`elem` newbndrs) expr
     (exprbinding_maybe, expr') <- doexpr expr uses_bndrs
     -- Create a new alternative
-    let newalt = (con, newbndrs, expr')
+    let newalt = (DataAlt dc, newbndrs, expr')
     let bindings = Maybe.catMaybes (bindings_maybe ++ [exprbinding_maybe])
     return (bindings, newalt)
     where
@@ -522,7 +523,8 @@ casesimpl c expr@(Case scrut bndr ty alts) | not bndr_used = do
         -- inlinenonrep).
         if (not wild) && repr
           then do
-            caseexpr <- Trans.lift $ mkSelCase scrut i
+            let dc_i = datacon_index (CoreUtils.exprType scrut) dc
+            caseexpr <- Trans.lift $ mkSelCase scrut dc_i i
             -- Create a new binder that will actually capture a value in this
             -- case statement, and return it.
             return (wildbndrs!!i, Just (b, caseexpr))
@@ -753,7 +755,7 @@ inlinenonrepresult :: Transform
 -- that is fully applied (i.e., dos not have a function type) but is not
 -- representable. We apply in any context, since non-representable
 -- expressions are generally left alone and can occur anywhere.
-inlinenonrepresult context expr | not (is_fun expr) =
+inlinenonrepresult context expr | not (is_applicable expr) && not (has_free_tyvars expr) =
   case collectArgs expr of
     (Var f, args) | not (Id.isDictId f) -> do
       repr <- isRepr expr
@@ -794,7 +796,7 @@ inlinenonrepresult context expr | not (is_fun expr) =
                   res_bndr <- Trans.lift $ mkBinderFor newapp "res"
                   -- Create extractor case expressions to extract each of the
                   -- free variables from the tuple.
-                  sel_cases <- Trans.lift $ mapM (mkSelCase (Var res_bndr)) [0..n_free_vars-1]
+                  sel_cases <- Trans.lift $ mapM (mkSelCase (Var res_bndr) 0) [0..n_free_vars-1]
 
                   -- Bind the res_bndr to the result of the new application
                   -- and each of the free variables to the corresponding
@@ -821,6 +823,10 @@ inlinenonrepresult context expr | not (is_fun expr) =
 -- Leave all other expressions unchanged
 inlinenonrepresult c expr = return expr
 
+----------------------------------------------------------------
+-- Type-class transformations
+----------------------------------------------------------------
+
 --------------------------------
 -- ClassOp resolution
 --------------------------------
@@ -952,7 +958,7 @@ letmerge c expr = return expr
 -- What transforms to run?
 transforms = [ ("inlinedict", inlinedict)
              , ("inlinetoplevel", inlinetoplevel)
-             -- , ("inlinenonrepresult", inlinenonrepresult)
+             , ("inlinenonrepresult", inlinenonrepresult)
              , ("knowncase", knowncase)
              , ("classopresolution", classopresolution)
              , ("argprop", argprop)