-- top level function "normalize", and defines the actual transformation passes that
-- are performed.
--
-module CLasH.Normalize (getNormalized, normalizeExpr) where
+module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where
-- Standard modules
import Debug.Trace
import qualified Maybe
+import qualified List
import qualified "transformers" Control.Monad.Trans as Trans
import qualified Control.Monad as Monad
import qualified Control.Monad.Trans.Writer as Writer
castproptop = everywhere ("castprop", castprop)
--------------------------------
--- let recursification
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
--------------------------------
-letrec, letrectop :: Transform
-letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+ -- Don't extract values that are already simpl
+ local_var <- Trans.lift $ is_local_var val
+ -- Don't extract values that are not representable, to prevent loops with
+ -- inlinenonrep
+ repr <- isRepr val
+ if (not local_var) && repr
+ then do
+ -- Generate a binder for the expression
+ id <- Trans.lift $ mkBinderFor val "castval"
+ -- Extract the expression
+ change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+ else
+ return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
+--------------------------------
+-- let derecursification
+--------------------------------
+letderec, letderectop :: Transform
+letderec expr@(Let (Rec binds) res) = case liftable of
+ -- Nothing is liftable, just return
+ [] -> return expr
+ -- Something can be lifted, generate a new let expression
+ _ -> change $ MkCore.mkCoreLets newbinds res
+ where
+ -- Make a list of all the binders bound in this recursive let
+ bndrs = map fst binds
+ -- See which bindings are liftable
+ (liftable, nonliftable) = List.partition canlift binds
+ -- Create nonrec bindings for each liftable binding and a single recursive
+ -- binding for all others
+ newbinds = (map (uncurry NonRec) liftable) ++ [Rec nonliftable]
+ -- Any expression that does not use any of the binders in this recursive let
+ -- can be lifted into a nonrec let. It can't use its own binder either,
+ -- since that would mean the binding is self-recursive and should be in a
+ -- single bind recursive let.
+ canlift (bndr, e) = not $ expr_uses_binders bndrs e
-- Leave all other expressions unchanged
-letrec expr = return expr
+letderec expr = return expr
-- Perform this transform everywhere
-letrectop = everywhere ("letrec", letrec)
+letderectop = everywhere ("letderec", letderec)
--------------------------------
-- let simplification
then do
-- If the result is not a local var already (to prevent loops with
-- ourselves), extract it.
- id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
+ id <- Trans.lift $ mkBinderFor res "foo"
let bind = (id, res)
change $ Let (Rec (bind:binds)) (Var id)
else
letremoveunused expr = return expr
letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
+--------------------------------
+-- Identical let binding merging
+--------------------------------
+-- Merge two bindings in a let if they are identical
+-- TODO: We would very much like to use GHC's CSE module for this, but that
+-- doesn't track if something changed or not, so we can't use it properly.
+letmerge, letmergetop :: Transform
+letmerge expr@(Let (Rec binds) res) = do
+ binds' <- domerge binds
+ return (Let (Rec binds') res)
+ where
+ domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)]
+ domerge [] = return []
+ domerge (e:es) = do
+ es' <- mapM (mergebinds e) es
+ es'' <- domerge es'
+ return (e:es'')
+
+ -- Uses the second bind to simplify the second bind, if applicable.
+ mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+ mergebinds (b1, e1) (b2, e2)
+ -- Identical expressions? Replace the second binding with a reference to
+ -- the first binder.
+ | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1)
+ -- Different expressions? Don't change
+ | otherwise = return (b2, e2)
+-- Leave all other expressions unchanged
+letmerge expr = return expr
+letmergetop = everywhere ("letmerge", letmerge)
+
--------------------------------
-- Function inlining
--------------------------------
repr <- isRepr scrut
if repr
then do
- id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
+ id <- Trans.lift $ mkBinderFor scrut "scrut"
change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
else
return expr
-- prevent loops with inlinenonrep).
if (not uses_bndrs) && (not local_var) && repr
then do
- id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
+ id <- Trans.lift $ mkBinderFor expr "caseval"
-- We don't flag a change here, since casevalsimpl will do that above
-- based on Just we return here.
return $ (Just (id, expr), Var id)
local_var <- Trans.lift $ is_local_var arg
if repr && not local_var
then do -- Extract representable arguments
- id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
+ id <- Trans.lift $ mkBinderFor arg "arg"
change $ Let (Rec [(id, arg)]) (App f (Var id))
else -- Leave non-representable arguments unchanged
return expr
-- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letderectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop]
-- | Returns the normalized version of the given function.
getNormalized ::
-- the last let).
let expr' = Let (Rec []) expr
-- Normalize this expression
- trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+ trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
expr'' <- dotransforms transforms expr'
- trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
+ trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
return expr''
-- | Get the value that is bound to the given binder at top level. Fails when
-- If the binding isn't in the "cache" (bindings map), then we can't create
-- it out of thin air, so return an error.
error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
+
+-- | Split a normalized expression into the argument binders, top level
+-- bindings and the result binder.
+splitNormalized ::
+ CoreExpr -- ^ The normalized expression
+ -> ([CoreBndr], [Binding], CoreBndr)
+splitNormalized expr = (args, binds, res)
+ where
+ (args, letexpr) = CoreSyn.collectBinders expr
+ (binds, resexpr) = flattenLets letexpr
+ res = case resexpr of
+ (Var x) -> x
+ _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- | Flattens nested lets into a single list of bindings. The expression
+-- passed does not have to be a let expression, if it isn't an empty list of
+-- bindings is returned.
+flattenLets ::
+ CoreExpr -- ^ The expression to flatten.
+ -> ([Binding], CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (Let binds expr) =
+ (bindings ++ bindings', expr')
+ where
+ -- Recursively flatten the contained expression
+ (bindings', expr') =flattenLets expr
+ -- Flatten our own bindings to remove the Rec / NonRec constructors
+ bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)