Split off selector case creation code into CoreTools.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 85be0d0ef61ae37050835a9784185a538db04990..8bc2ef0447bea16a77f70b6fa62651cc93e9d4cf 100644 (file)
@@ -358,10 +358,10 @@ needsInline f = do
         case norm_maybe of
           -- Noth normalizeable
           Nothing -> return Nothing 
-          Just norm -> case splitNormalized norm of
+          Just norm -> case splitNormalizedNonRep norm of
             -- The function has just a single binding, so that's simple
             -- enough to inline.
-            (args, [bind], res) -> return $ Just norm
+            (args, [bind], Var res) -> return $ Just norm
             -- More complicated function, don't inline
             _ -> return Nothing
             
@@ -569,12 +569,9 @@ casesimpl c expr@(Case scrut bndr ty alts) | not bndr_used = do
         -- inlinenonrep).
         if (not wild) && repr
           then do
-            -- Create on new binder that will actually capture a value in this
+            caseexpr <- Trans.lift $ mkSelCase scrut i
+            -- Create a new binder that will actually capture a value in this
             -- case statement, and return it.
-            let bty = (Id.idType b)
-            id <- Trans.lift $ mkInternalVar "sel" bty
-            let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
-            let caseexpr = Case scrut b bty [(con, binders, Var id)]
             return (wildbndrs!!i, Just (b, caseexpr))
           else 
             -- Just leave the original binder in place, and don't generate an
@@ -852,14 +849,21 @@ normalizeExpr what expr = do
        return expr'
 
 -- | Split a normalized expression into the argument binders, top level
---   bindings and the result binder.
+--   bindings and the result binder. This function returns an error if
+--   the type of the expression is not representable.
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = (args, binds, res)
+splitNormalized expr = 
+  case splitNormalizedNonRep expr of
+    (args, binds, Var res) -> (args, binds, res)
+    _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- Split a normalized expression, whose type can be unrepresentable.
+splitNormalizedNonRep::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreExpr)
+splitNormalizedNonRep expr = (args, binds, resexpr)
   where
     (args, letexpr) = CoreSyn.collectBinders expr
     (binds, resexpr) = flattenLets letexpr
-    res = case resexpr of 
-      (Var x) -> x
-      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"