X-Git-Url: https://git.stderr.nl/gitweb?a=blobdiff_plain;f=c%CE%BBash%2FCLasH%2FNormalize.hs;h=192ed55cf0500c1655fb9b1adba6c3a79512d596;hb=60174903a7e142bf05586c24498b7e064a7118ff;hp=647bd1902592c46db67c132a9ff744901a82b971;hpb=14c4b2bf87d936f2123c237a26503011ccace963;p=matthijs%2Fmaster-project%2Fc%CE%BBash.git diff --git "a/c\316\273ash/CLasH/Normalize.hs" "b/c\316\273ash/CLasH/Normalize.hs" index 647bd19..192ed55 100644 --- "a/c\316\273ash/CLasH/Normalize.hs" +++ "b/c\316\273ash/CLasH/Normalize.hs" @@ -4,11 +4,12 @@ -- top level function "normalize", and defines the actual transformation passes that -- are performed. -- -module CLasH.Normalize (getNormalized, normalizeExpr) where +module CLasH.Normalize (getNormalized, normalizeExpr, splitNormalized) where -- Standard modules import Debug.Trace import qualified Maybe +import qualified List import qualified "transformers" Control.Monad.Trans as Trans import qualified Control.Monad as Monad import qualified Control.Monad.Trans.Writer as Writer @@ -91,31 +92,101 @@ castprop expr = return expr castproptop = everywhere ("castprop", castprop) -------------------------------- --- let recursification +-- Cast simplification. Mostly useful for state packing and unpacking, but +-- perhaps for others as well. -------------------------------- -letrec, letrectop :: Transform -letrec (Let (NonRec b expr) res) = change $ Let (Rec [(b, expr)]) res +castsimpl, castsimpltop :: Transform +castsimpl expr@(Cast val ty) = do + -- Don't extract values that are already simpl + local_var <- Trans.lift $ is_local_var val + -- Don't extract values that are not representable, to prevent loops with + -- inlinenonrep + repr <- isRepr val + if (not local_var) && repr + then do + -- Generate a binder for the expression + id <- Trans.lift $ mkBinderFor val "castval" + -- Extract the expression + change $ Let (NonRec id val) (Cast (Var id) ty) + else + return expr +-- Leave all other expressions unchanged +castsimpl expr = return expr +-- Perform this transform everywhere +castsimpltop = everywhere ("castsimpl", castsimpl) + + +-------------------------------- +-- Lambda simplication +-------------------------------- +-- Ensure that a lambda always evaluates to a let expressions or a simple +-- variable reference. +lambdasimpl, lambdasimpltop :: Transform +-- Don't simplify a lambda that evaluates to let, since this is already +-- normal form (and would cause infinite loops). +lambdasimpl expr@(Lam _ (Let _ _)) = return expr +-- Put the of a lambda in its own binding, but not when the expression is +-- already a local variable, or not representable (to prevent loops with +-- inlinenonrep). +lambdasimpl expr@(Lam bndr res) = do + repr <- isRepr res + local_var <- Trans.lift $ is_local_var res + if not local_var && repr + then do + id <- Trans.lift $ mkBinderFor res "res" + change $ Lam bndr (Let (NonRec id res) (Var id)) + else + -- If the result is already a local var or not representable, don't + -- extract it. + return expr + +-- Leave all other expressions unchanged +lambdasimpl expr = return expr +-- Perform this transform everywhere +lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl) + +-------------------------------- +-- let derecursification +-------------------------------- +letderec, letderectop :: Transform +letderec expr@(Let (Rec binds) res) = case liftable of + -- Nothing is liftable, just return + [] -> return expr + -- Something can be lifted, generate a new let expression + _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res) + where + -- Make a list of all the binders bound in this recursive let + bndrs = map fst binds + -- See which bindings are liftable + (liftable, nonliftable) = List.partition canlift binds + -- Any expression that does not use any of the binders in this recursive let + -- can be lifted into a nonrec let. It can't use its own binder either, + -- since that would mean the binding is self-recursive and should be in a + -- single bind recursive let. + canlift (bndr, e) = not $ expr_uses_binders bndrs e -- Leave all other expressions unchanged -letrec expr = return expr +letderec expr = return expr -- Perform this transform everywhere -letrectop = everywhere ("letrec", letrec) +letderectop = everywhere ("letderec", letderec) -------------------------------- -- let simplification -------------------------------- letsimpl, letsimpltop :: Transform +-- Don't simplify a let that evaluates to another let, since this is already +-- normal form (and would cause infinite loops with letflat below). +letsimpl expr@(Let _ (Let _ _)) = return expr -- Put the "in ..." value of a let in its own binding, but not when the -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep). -letsimpl expr@(Let (Rec binds) res) = do +letsimpl expr@(Let binds res) = do repr <- isRepr res local_var <- Trans.lift $ is_local_var res if not local_var && repr then do -- If the result is not a local var already (to prevent loops with -- ourselves), extract it. - id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res) - let bind = (id, res) - change $ Let (Rec (bind:binds)) (Var id) + id <- Trans.lift $ mkBinderFor res "foo" + change $ Let binds (Let (NonRec id res) (Var id)) else -- If the result is already a local var, don't extract it. return expr @@ -128,13 +199,18 @@ letsimpltop = everywhere ("letsimpl", letsimpl) -------------------------------- -- let flattening -------------------------------- +-- Takes a let that binds another let, and turns that into two nested lets. +-- e.g., from: +-- let b = (let b' = expr' in res') in res +-- to: +-- let b' = expr' in (let b = res' in res) letflat, letflattop :: Transform +-- Turn a nonrec let that binds a let into two nested lets. +letflat (Let (NonRec b (Let binds res')) res) = + change $ Let binds (Let (NonRec b res') res) letflat (Let (Rec binds) expr) = do - -- Turn each binding into a list of bindings (possibly containing just one - -- element, of course) - bindss <- Monad.mapM flatbind binds - -- Concat all the bindings - let binds' = concat bindss + -- Flatten each binding. + binds' <- Utils.concatM $ Monad.mapM flatbind binds -- Return the new let. We don't use change here, since possibly nothing has -- changed. If anything has changed, flatbind has already flagged that -- change. @@ -150,17 +226,33 @@ letflat expr = return expr -- Perform this transform everywhere letflattop = everywhere ("letflat", letflat) +-------------------------------- +-- empty let removal +-------------------------------- +-- Remove empty (recursive) lets +letremove, letremovetop :: Transform +letremove (Let (Rec []) res) = change $ res +-- Leave all other expressions unchanged +letremove expr = return expr +-- Perform this transform everywhere +letremovetop = everywhere ("letremove", letremove) + -------------------------------- -- Simple let binding removal -------------------------------- -- Remove a = b bindings from let expressions everywhere -letremovetop :: Transform -letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e)) +letremovesimpletop :: Transform +letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans.lift $ is_local_var e)) -------------------------------- -- Unused let binding removal -------------------------------- letremoveunused, letremoveunusedtop :: Transform +letremoveunused expr@(Let (NonRec b bound) res) = do + let used = expr_uses_binders [b] res + if used + then return expr + else change res letremoveunused expr@(Let (Rec binds) res) = do -- Filter out all unused binds. let binds' = filter dobind binds @@ -170,11 +262,44 @@ letremoveunused expr@(Let (Rec binds) res) = do bound_exprs = map snd binds -- For each bind check if the bind is used by res or any of the bound -- expressions - dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs) + dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs) -- Leave all other expressions unchanged letremoveunused expr = return expr letremoveunusedtop = everywhere ("letremoveunused", letremoveunused) +{- +-------------------------------- +-- Identical let binding merging +-------------------------------- +-- Merge two bindings in a let if they are identical +-- TODO: We would very much like to use GHC's CSE module for this, but that +-- doesn't track if something changed or not, so we can't use it properly. +letmerge, letmergetop :: Transform +letmerge expr@(Let _ _) = do + let (binds, res) = flattenLets expr + binds' <- domerge binds + return $ mkNonRecLets binds' res + where + domerge :: [(CoreBndr, CoreExpr)] -> TransformMonad [(CoreBndr, CoreExpr)] + domerge [] = return [] + domerge (e:es) = do + es' <- mapM (mergebinds e) es + es'' <- domerge es' + return (e:es'') + + -- Uses the second bind to simplify the second bind, if applicable. + mergebinds :: (CoreBndr, CoreExpr) -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr) + mergebinds (b1, e1) (b2, e2) + -- Identical expressions? Replace the second binding with a reference to + -- the first binder. + | CoreUtils.cheapEqExpr e1 e2 = change $ (b2, Var b1) + -- Different expressions? Don't change + | otherwise = return (b2, e2) +-- Leave all other expressions unchanged +letmerge expr = return expr +letmergetop = everywhere ("letmerge", letmerge) +-} + -------------------------------- -- Function inlining -------------------------------- @@ -207,8 +332,8 @@ scrutsimpl expr@(Case scrut b ty alts) = do repr <- isRepr scrut if repr then do - id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut) - change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts) + id <- Trans.lift $ mkBinderFor scrut "scrut" + change $ Let (NonRec id scrut) (Case (Var id) b ty alts) else return expr -- Leave all other expressions unchanged @@ -234,7 +359,7 @@ casesimpl expr@(Case scrut b ty alts) = do (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts let bindings = concat bindingss -- Replace the case with a let with bindings and a case - let newlet = (Let (Rec bindings) (Case scrut b ty alts')) + let newlet = mkNonRecLets bindings (Case scrut b ty alts') -- If there are no non-wild binders, or this case is already a simple -- selector (i.e., a single alt with exactly one binding), already a simple -- selector altan no bindings (i.e., no wild binders in the original case), @@ -257,7 +382,7 @@ casesimpl expr@(Case scrut b ty alts) = do (exprbinding_maybe, expr') <- doexpr expr uses_bndrs -- Create a new alternative let newalt = (con, newbndrs, expr') - let bindings = Maybe.catMaybes (exprbinding_maybe : bindings_maybe) + let bindings = Maybe.catMaybes (bindings_maybe ++ [exprbinding_maybe]) return (bindings, newalt) where -- Make wild alternatives for each binder @@ -301,7 +426,7 @@ casesimpl expr@(Case scrut b ty alts) = do -- prevent loops with inlinenonrep). if (not uses_bndrs) && (not local_var) && repr then do - id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr) + id <- Trans.lift $ mkBinderFor expr "caseval" -- We don't flag a change here, since casevalsimpl will do that above -- based on Just we return here. return $ (Just (id, expr), Var id) @@ -341,8 +466,8 @@ appsimpl expr@(App f arg) = do local_var <- Trans.lift $ is_local_var arg if repr && not local_var then do -- Extract representable arguments - id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg) - change $ Let (Rec [(id, arg)]) (App f (Var id)) + id <- Trans.lift $ mkBinderFor arg "arg" + change $ Let (NonRec id arg) (App f (Var id)) else -- Leave non-representable arguments unchanged return expr -- Leave all other expressions unchanged @@ -491,7 +616,7 @@ funextracttop = everywhere ("funextract", funextract) -- What transforms to run? -transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop] +transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop] -- | Returns the normalized version of the given function. getNormalized :: @@ -515,15 +640,11 @@ normalizeExpr :: -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression normalizeExpr what expr = do - -- Introduce an empty Let at the top level, so there will always be - -- a let in the expression (none of the transformations will remove - -- the last let). - let expr' = Let (Rec []) expr -- Normalize this expression - trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return () - expr'' <- dotransforms transforms expr' - trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return () - return expr'' + trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr ) ++ "\n") $ return () + expr' <- dotransforms transforms expr + trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return () + return expr' -- | Get the value that is bound to the given binder at top level. Fails when -- there is no such binding. @@ -535,3 +656,16 @@ getBinding bndr = Utils.makeCached bndr tsBindings $ do -- If the binding isn't in the "cache" (bindings map), then we can't create -- it out of thin air, so return an error. error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr + +-- | Split a normalized expression into the argument binders, top level +-- bindings and the result binder. +splitNormalized :: + CoreExpr -- ^ The normalized expression + -> ([CoreBndr], [Binding], CoreBndr) +splitNormalized expr = (args, binds, res) + where + (args, letexpr) = CoreSyn.collectBinders expr + (binds, resexpr) = flattenLets letexpr + res = case resexpr of + (Var x) -> x + _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"