Make top level inlining handle non-representable results gracefully.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index a550124e0f01fccaa2dd8bcfb07b0e2f7cfa94b1..620482f703547f65f3dff7eb888742f70e67421d 100644 (file)
@@ -121,11 +121,23 @@ castsimpl c expr = return expr
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
 --------------------------------
--- Ensure that a function that just returns another function (or rather,
--- another top-level binder) is still properly normalized. This is a temporary
--- solution, we should probably integrate this pass with lambdasimpl and
--- letsimpl instead.
+-- Return value simplification
 --------------------------------
+-- Ensure the return value of a function follows proper normal form. eta
+-- expansion ensures the body starts with lambda abstractions, this
+-- transformation ensures that the lambda abstractions always contain a
+-- recursive let and that, when the return value is representable, the
+-- let contains a local variable reference in its body.
+retvalsimpl c expr | all (== LambdaBody) c && not (is_lam expr) && not (is_let expr) = do
+  local_var <- Trans.lift $ is_local_var expr
+  repr <- isRepr expr
+  if not local_var && repr
+    then do
+      id <- Trans.lift $ mkBinderFor expr "res" 
+      change $ Let (Rec [(id, expr)]) (Var id)
+    else
+      return expr
+
 retvalsimpl c expr@(Let (Rec binds) body) | all (== LambdaBody) c = do
   -- Don't extract values that are already a local variable, to prevent
   -- loops with ourselves.
@@ -140,15 +152,6 @@ retvalsimpl c expr@(Let (Rec binds) body) | all (== LambdaBody) c = do
     else
       return expr
 
-retvalsimpl c expr | all (== LambdaBody) c && not (is_lam expr) && not (is_let expr) = do
-  local_var <- Trans.lift $ is_local_var expr
-  repr <- isRepr expr
-  if not local_var && repr
-    then do
-      id <- Trans.lift $ mkBinderFor expr "res" 
-      change $ Let (Rec [(id, expr)]) (Var id)
-    else
-      return expr
 
 -- Leave all other expressions unchanged
 retvalsimpl c expr = return expr
@@ -158,26 +161,14 @@ retvalsimpltop = everywhere ("retvalsimpl", retvalsimpl)
 --------------------------------
 -- let derecursification
 --------------------------------
-letderec, letderectop :: Transform
-letderec c expr@(Let (Rec binds) res) = case liftable of
-  -- Nothing is liftable, just return
-  [] -> return expr
-  -- Something can be lifted, generate a new let expression
-  _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res)
-  where
-    -- Make a list of all the binders bound in this recursive let
-    bndrs = map fst binds
-    -- See which bindings are liftable
-    (liftable, nonliftable) = List.partition canlift binds
-    -- Any expression that does not use any of the binders in this recursive let
-    -- can be lifted into a nonrec let. It can't use its own binder either,
-    -- since that would mean the binding is self-recursive and should be in a
-    -- single bind recursive let.
-    canlift (bndr, e) = not $ expr_uses_binders bndrs e
+letrec, letrectop :: Transform
+letrec c expr@(Let (NonRec bndr val) res) = 
+  change $ Let (Rec [(bndr, val)]) res
+
 -- Leave all other expressions unchanged
-letderec c expr = return expr
+letrec c expr = return expr
 -- Perform this transform everywhere
-letderectop = everywhere ("letderec", letderec)
+letrectop = everywhere ("letrec", letrec)
 
 --------------------------------
 -- let flattening
@@ -367,10 +358,10 @@ needsInline f = do
         case norm_maybe of
           -- Noth normalizeable
           Nothing -> return Nothing 
-          Just norm -> case splitNormalized norm of
+          Just norm -> case splitNormalizedNonRep norm of
             -- The function has just a single binding, so that's simple
             -- enough to inline.
-            (args, [bind], res) -> return $ Just norm
+            (args, [bind], Var res) -> return $ Just norm
             -- More complicated function, don't inline
             _ -> return Nothing
             
@@ -805,7 +796,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, retvalsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop]
+transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letrectop, letremovetop, retvalsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function, or an error
 -- if it is not a known global binder.
@@ -861,14 +852,21 @@ normalizeExpr what expr = do
        return expr'
 
 -- | Split a normalized expression into the argument binders, top level
---   bindings and the result binder.
+--   bindings and the result binder. This function returns an error if
+--   the type of the expression is not representable.
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = (args, binds, res)
+splitNormalized expr = 
+  case splitNormalizedNonRep expr of
+    (args, binds, Var res) -> (args, binds, res)
+    _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- Split a normalized expression, whose type can be unrepresentable.
+splitNormalizedNonRep::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreExpr)
+splitNormalizedNonRep expr = (args, binds, resexpr)
   where
     (args, letexpr) = CoreSyn.collectBinders expr
     (binds, resexpr) = flattenLets letexpr
-    res = case resexpr of 
-      (Var x) -> x
-      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"