X-Git-Url: https://git.stderr.nl/gitweb?a=blobdiff_plain;f=c%CE%BBash%2FCLasH%2FNormalize.hs;h=7fb0dc235d2a60d8a3e3798b8a655e9f6f3218f7;hb=10dfe589f40e65d51ca1585beecf00ae85169cae;hp=8b35bb986bd4265df2ec58e6988fc21f04fe64de;hpb=d42006e5be8a3d6e3970b84767856b77d00f7f73;p=matthijs%2Fmaster-project%2Fc%CE%BBash.git diff --git "a/c\316\273ash/CLasH/Normalize.hs" "b/c\316\273ash/CLasH/Normalize.hs" index 8b35bb9..7fb0dc2 100644 --- "a/c\316\273ash/CLasH/Normalize.hs" +++ "b/c\316\273ash/CLasH/Normalize.hs" @@ -4,7 +4,7 @@ -- top level function "normalize", and defines the actual transformation passes that -- are performed. -- -module CLasH.Normalize (getNormalized, normalizeExpr) where +module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where -- Standard modules import Debug.Trace @@ -90,6 +90,30 @@ castprop expr = return expr -- Perform this transform everywhere castproptop = everywhere ("castprop", castprop) +-------------------------------- +-- Cast simplification. Mostly useful for state packing and unpacking, but +-- perhaps for others as well. +-------------------------------- +castsimpl, castsimpltop :: Transform +castsimpl expr@(Cast val ty) = do + -- Don't extract values that are already simpl + local_var <- Trans.lift $ is_local_var val + -- Don't extract values that are not representable, to prevent loops with + -- inlinenonrep + repr <- isRepr val + if (not local_var) && repr + then do + -- Generate a binder for the expression + id <- Trans.lift $ mkBinderFor val "castval" + -- Extract the expression + change $ Let (Rec [(id, val)]) (Cast (Var id) ty) + else + return expr +-- Leave all other expressions unchanged +castsimpl expr = return expr +-- Perform this transform everywhere +castsimpltop = everywhere ("castsimpl", castsimpl) + -------------------------------- -- let recursification -------------------------------- @@ -113,7 +137,7 @@ letsimpl expr@(Let (Rec binds) res) = do then do -- If the result is not a local var already (to prevent loops with -- ourselves), extract it. - id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res) + id <- Trans.lift $ mkBinderFor res "foo" let bind = (id, res) change $ Let (Rec (bind:binds)) (Var id) else @@ -157,6 +181,54 @@ letflattop = everywhere ("letflat", letflat) letremovetop :: Transform letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e)) +-------------------------------- +-- Unused let binding removal +-------------------------------- +letremoveunused, letremoveunusedtop :: Transform +letremoveunused expr@(Let (Rec binds) res) = do + -- Filter out all unused binds. + let binds' = filter dobind binds + -- Only set the changed flag if binds got removed + changeif (length binds' /= length binds) (Let (Rec binds') res) + where + bound_exprs = map snd binds + -- For each bind check if the bind is used by res or any of the bound + -- expressions + dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs) +-- Leave all other expressions unchanged +letremoveunused expr = return expr +letremoveunusedtop = everywhere ("letremoveunused", letremoveunused) + +-------------------------------- +-- Identical let binding merging +-------------------------------- +-- Merge two bindings in a let if they are identical +-- TODO: We would very much like to use GHC's CSE module for this, but that +-- doesn't track if something changed or not, so we can't use it properly. +letmerge, letmergetop :: Transform +letmerge expr@(Let (Rec binds) res) = do + binds' <- domerge binds + return (Let (Rec binds') res) + where + domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)] + domerge [] = return [] + domerge (e:es) = do + es' <- mapM (mergebinds e) es + es'' <- domerge es' + return (e:es'') + + -- Uses the second bind to simplify the second bind, if applicable. + mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr) + mergebinds (b1, e1) (b2, e2) + -- Identical expressions? Replace the second binding with a reference to + -- the first binder. + | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1) + -- Different expressions? Don't change + | otherwise = return (b2, e2) +-- Leave all other expressions unchanged +letmerge expr = return expr +letmergetop = everywhere ("letmerge", letmerge) + -------------------------------- -- Function inlining -------------------------------- @@ -189,7 +261,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do repr <- isRepr scrut if repr then do - id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut) + id <- Trans.lift $ mkBinderFor scrut "scrut" change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts) else return expr @@ -283,7 +355,7 @@ casesimpl expr@(Case scrut b ty alts) = do -- prevent loops with inlinenonrep). if (not uses_bndrs) && (not local_var) && repr then do - id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr) + id <- Trans.lift $ mkBinderFor expr "caseval" -- We don't flag a change here, since casevalsimpl will do that above -- based on Just we return here. return $ (Just (id, expr), Var id) @@ -323,7 +395,7 @@ appsimpl expr@(App f arg) = do local_var <- Trans.lift $ is_local_var arg if repr && not local_var then do -- Extract representable arguments - id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg) + id <- Trans.lift $ mkBinderFor arg "arg" change $ Let (Rec [(id, arg)]) (App f (Var id)) else -- Leave non-representable arguments unchanged return expr @@ -473,7 +545,7 @@ funextracttop = everywhere ("funextract", funextract) -- What transforms to run? -transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop] +transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letmergetop, letremoveunusedtop, castsimpltop] -- | Returns the normalized version of the given function. getNormalized :: @@ -502,9 +574,9 @@ normalizeExpr what expr = do -- the last let). let expr' = Let (Rec []) expr -- Normalize this expression - trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return () + trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return () expr'' <- dotransforms transforms expr' - trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return () + trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr'')) $ return () return expr'' -- | Get the value that is bound to the given binder at top level. Fails when @@ -517,3 +589,15 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do -- If the binding isn't in the "cache" (bindings map), then we can't create -- it out of thin air, so return an error. error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr + +-- | Split a normalized expression into the argument binders, top level +-- bindings and the result binder. +splitNormalized :: + CoreExpr -- ^ The normalized expression + -> ([CoreBndr], [Binding], CoreBndr) +splitNormalized expr = + case letexpr of + (Let (Rec binds) (Var res)) -> (args, binds, res) + _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n" + where + (args, letexpr) = CoreSyn.collectBinders expr