castsimpltop = everywhere ("castsimpl", castsimpl)
--------------------------------
--- Ensure that a function that just returns another function (or rather,
--- another top-level binder) is still properly normalized. This is a temporary
--- solution, we should probably integrate this pass with lambdasimpl and
--- letsimpl instead.
+-- Return value simplification
--------------------------------
+-- Ensure the return value of a function follows proper normal form. eta
+-- expansion ensures the body starts with lambda abstractions, this
+-- transformation ensures that the lambda abstractions always contain a
+-- recursive let and that, when the return value is representable, the
+-- let contains a local variable reference in its body.
+retvalsimpl c expr | all (== LambdaBody) c && not (is_lam expr) && not (is_let expr) = do
+ local_var <- Trans.lift $ is_local_var expr
+ repr <- isRepr expr
+ if not local_var && repr
+ then do
+ id <- Trans.lift $ mkBinderFor expr "res"
+ change $ Let (Rec [(id, expr)]) (Var id)
+ else
+ return expr
+
retvalsimpl c expr@(Let (Rec binds) body) | all (== LambdaBody) c = do
-- Don't extract values that are already a local variable, to prevent
-- loops with ourselves.
else
return expr
-retvalsimpl c expr | all (== LambdaBody) c && not (is_lam expr) && not (is_let expr) = do
- local_var <- Trans.lift $ is_local_var expr
- repr <- isRepr expr
- if not local_var && repr
- then do
- id <- Trans.lift $ mkBinderFor expr "res"
- change $ Let (Rec [(id, expr)]) (Var id)
- else
- return expr
-- Leave all other expressions unchanged
retvalsimpl c expr = return expr
return expr'
-- | Split a normalized expression into the argument binders, top level
--- bindings and the result binder.
+-- bindings and the result binder. This function returns an error if
+-- the type of the expression is not representable.
splitNormalized ::
CoreExpr -- ^ The normalized expression
-> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = (args, binds, res)
+splitNormalized expr =
+ case splitNormalizedNonRep expr of
+ (args, binds, Var res) -> (args, binds, res)
+ _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- Split a normalized expression, whose type can be unrepresentable.
+splitNormalizedNonRep::
+ CoreExpr -- ^ The normalized expression
+ -> ([CoreBndr], [Binding], CoreExpr)
+splitNormalizedNonRep expr = (args, binds, resexpr)
where
(args, letexpr) = CoreSyn.collectBinders expr
(binds, resexpr) = flattenLets letexpr
- res = case resexpr of
- (Var x) -> x
- _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"