From 3deb1d21f696f8495cd99345c9677210e2a2fc79 Mon Sep 17 00:00:00 2001 From: Matthijs Kooijman Date: Tue, 30 Mar 2010 14:25:54 +0200 Subject: [PATCH] Pass the context in which an expression occurs to each transformation. Currently, the only context supported is "Other" and contexts are only passed around and not yet used anywhere. --- "c\316\273ash/CLasH/Normalize.hs" | 114 +++++++++--------- .../CLasH/Normalize/NormalizeTools.hs" | 71 ++++++----- .../CLasH/Normalize/NormalizeTypes.hs" | 5 +- 3 files changed, 96 insertions(+), 94 deletions(-) diff --git "a/c\316\273ash/CLasH/Normalize.hs" "b/c\316\273ash/CLasH/Normalize.hs" index 31474d5..fa6ae8c 100644 --- "a/c\316\273ash/CLasH/Normalize.hs" +++ "b/c\316\273ash/CLasH/Normalize.hs" @@ -48,12 +48,12 @@ import CLasH.Utils.Pretty -- η abstraction -------------------------------- eta, etatop :: Transform -eta expr | is_fun expr && not (is_lam expr) = do +eta c expr | is_fun expr && not (is_lam expr) = do let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr id <- Trans.lift $ mkInternalVar "param" arg_ty change (Lam id (App expr (Var id))) -- Leave all other expressions unchanged -eta e = return e +eta c e = return e etatop = notappargs ("eta", eta) -------------------------------- @@ -62,17 +62,17 @@ etatop = notappargs ("eta", eta) beta, betatop :: Transform -- Substitute arg for x in expr. For value lambda's, also clone before -- substitution. -beta (App (Lam x expr) arg) | CoreSyn.isTyVar x = setChanged >> substitute x arg expr - | otherwise = setChanged >> substitute_clone x arg expr +beta c (App (Lam x expr) arg) | CoreSyn.isTyVar x = setChanged >> substitute x arg c expr + | otherwise = setChanged >> substitute_clone x arg c expr -- Propagate the application into the let -beta (App (Let binds expr) arg) = change $ Let binds (App expr arg) +beta c (App (Let binds expr) arg) = change $ Let binds (App expr arg) -- Propagate the application into each of the alternatives -beta (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts' +beta c (App (Case scrut b ty alts) arg) = change $ Case scrut b ty' alts' where alts' = map (\(con, bndrs, expr) -> (con, bndrs, (App expr arg))) alts ty' = CoreUtils.applyTypeToArg ty arg -- Leave all other expressions unchanged -beta expr = return expr +beta c expr = return expr -- Perform this transform everywhere betatop = everywhere ("beta", beta) @@ -81,12 +81,12 @@ betatop = everywhere ("beta", beta) -------------------------------- -- Try to move casts as much downward as possible. castprop, castproptop :: Transform -castprop (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty) -castprop expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts') +castprop c (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty) +castprop c expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts') where alts' = map (\(con, bndrs, expr) -> (con, bndrs, (Cast expr ty))) alts -- Leave all other expressions unchanged -castprop expr = return expr +castprop c expr = return expr -- Perform this transform everywhere castproptop = everywhere ("castprop", castprop) @@ -95,7 +95,7 @@ castproptop = everywhere ("castprop", castprop) -- perhaps for others as well. -------------------------------- castsimpl, castsimpltop :: Transform -castsimpl expr@(Cast val ty) = do +castsimpl c expr@(Cast val ty) = do -- Don't extract values that are already simpl local_var <- Trans.lift $ is_local_var val -- Don't extract values that are not representable, to prevent loops with @@ -110,7 +110,7 @@ castsimpl expr@(Cast val ty) = do else return expr -- Leave all other expressions unchanged -castsimpl expr = return expr +castsimpl c expr = return expr -- Perform this transform everywhere castsimpltop = everywhere ("castsimpl", castsimpl) @@ -123,11 +123,11 @@ castsimpltop = everywhere ("castsimpl", castsimpl) lambdasimpl, lambdasimpltop :: Transform -- Don't simplify a lambda that evaluates to let, since this is already -- normal form (and would cause infinite loops). -lambdasimpl expr@(Lam _ (Let _ _)) = return expr +lambdasimpl c expr@(Lam _ (Let _ _)) = return expr -- Put the of a lambda in its own binding, but not when the expression is -- already a local variable, or not representable (to prevent loops with -- inlinenonrep). -lambdasimpl expr@(Lam bndr res) = do +lambdasimpl c expr@(Lam bndr res) = do repr <- isRepr res local_var <- Trans.lift $ is_local_var res if not local_var && repr @@ -140,7 +140,7 @@ lambdasimpl expr@(Lam bndr res) = do return expr -- Leave all other expressions unchanged -lambdasimpl expr = return expr +lambdasimpl c expr = return expr -- Perform this transform everywhere lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl) @@ -148,7 +148,7 @@ lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl) -- let derecursification -------------------------------- letderec, letderectop :: Transform -letderec expr@(Let (Rec binds) res) = case liftable of +letderec c expr@(Let (Rec binds) res) = case liftable of -- Nothing is liftable, just return [] -> return expr -- Something can be lifted, generate a new let expression @@ -164,7 +164,7 @@ letderec expr@(Let (Rec binds) res) = case liftable of -- single bind recursive let. canlift (bndr, e) = not $ expr_uses_binders bndrs e -- Leave all other expressions unchanged -letderec expr = return expr +letderec c expr = return expr -- Perform this transform everywhere letderectop = everywhere ("letderec", letderec) @@ -174,10 +174,10 @@ letderectop = everywhere ("letderec", letderec) letsimpl, letsimpltop :: Transform -- Don't simplify a let that evaluates to another let, since this is already -- normal form (and would cause infinite loops with letflat below). -letsimpl expr@(Let _ (Let _ _)) = return expr +letsimpl c expr@(Let _ (Let _ _)) = return expr -- Put the "in ..." value of a let in its own binding, but not when the -- expression is already a local variable, or not representable (to prevent loops with inlinenonrep). -letsimpl expr@(Let binds res) = do +letsimpl c expr@(Let binds res) = do repr <- isRepr res local_var <- Trans.lift $ is_local_var res if not local_var && repr @@ -191,7 +191,7 @@ letsimpl expr@(Let binds res) = do return expr -- Leave all other expressions unchanged -letsimpl expr = return expr +letsimpl c expr = return expr -- Perform this transform everywhere letsimpltop = everywhere ("letsimpl", letsimpl) @@ -205,9 +205,9 @@ letsimpltop = everywhere ("letsimpl", letsimpl) -- let b' = expr' in (let b = res' in res) letflat, letflattop :: Transform -- Turn a nonrec let that binds a let into two nested lets. -letflat (Let (NonRec b (Let binds res')) res) = +letflat c (Let (NonRec b (Let binds res')) res) = change $ Let binds (Let (NonRec b res') res) -letflat (Let (Rec binds) expr) = do +letflat c (Let (Rec binds) expr) = do -- Flatten each binding. binds' <- Utils.concatM $ Monad.mapM flatbind binds -- Return the new let. We don't use change here, since possibly nothing has @@ -222,7 +222,7 @@ letflat (Let (Rec binds) expr) = do flatbind (b, Let (NonRec b' expr') expr) = change [(b, expr), (b', expr')] flatbind (b, expr) = return [(b, expr)] -- Leave all other expressions unchanged -letflat expr = return expr +letflat c expr = return expr -- Perform this transform everywhere letflattop = everywhere ("letflat", letflat) @@ -231,9 +231,9 @@ letflattop = everywhere ("letflat", letflat) -------------------------------- -- Remove empty (recursive) lets letremove, letremovetop :: Transform -letremove (Let (Rec []) res) = change res +letremove c (Let (Rec []) res) = change res -- Leave all other expressions unchanged -letremove expr = return expr +letremove c expr = return expr -- Perform this transform everywhere letremovetop = everywhere ("letremove", letremove) @@ -248,12 +248,12 @@ letremovesimpletop = everywhere ("letremovesimple", inlinebind (\(b, e) -> Trans -- Unused let binding removal -------------------------------- letremoveunused, letremoveunusedtop :: Transform -letremoveunused expr@(Let (NonRec b bound) res) = do +letremoveunused c expr@(Let (NonRec b bound) res) = do let used = expr_uses_binders [b] res if used then return expr else change res -letremoveunused expr@(Let (Rec binds) res) = do +letremoveunused c expr@(Let (Rec binds) res) = do -- Filter out all unused binds. let binds' = filter dobind binds -- Only set the changed flag if binds got removed @@ -264,7 +264,7 @@ letremoveunused expr@(Let (Rec binds) res) = do -- expressions dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs) -- Leave all other expressions unchanged -letremoveunused expr = return expr +letremoveunused c expr = return expr letremoveunusedtop = everywhere ("letremoveunused", letremoveunused) {- @@ -275,7 +275,7 @@ letremoveunusedtop = everywhere ("letremoveunused", letremoveunused) -- TODO: We would very much like to use GHC's CSE module for this, but that -- doesn't track if something changed or not, so we can't use it properly. letmerge, letmergetop :: Transform -letmerge expr@(Let _ _) = do +letmerge c expr@(Let _ _) = do let (binds, res) = flattenLets expr binds' <- domerge binds return $ mkNonRecLets binds' res @@ -296,7 +296,7 @@ letmerge expr@(Let _ _) = do -- Different expressions? Don't change | otherwise = return (b2, e2) -- Leave all other expressions unchanged -letmerge expr = return expr +letmerge c expr = return expr letmergetop = everywhere ("letmerge", letmerge) -} @@ -344,10 +344,10 @@ inlinetoplevel, inlinetopleveltop :: Transform -- an infinite loop in transformation. Not inlining == is really a hack, -- but for now it keeps things working with the most common symptom of -- this problem. -inlinetoplevel expr@(Var f) | Name.getOccString f `elem` ["==", "/="] = return expr +inlinetoplevel c expr@(Var f) | Name.getOccString f `elem` ["==", "/="] = return expr -- Any system name is candidate for inlining. Never inline user-defined -- functions, to preserve structure. -inlinetoplevel expr@(Var f) | not $ isUserDefined f = do +inlinetoplevel c expr@(Var f) | not $ isUserDefined f = do body_maybe <- needsInline f case body_maybe of Just body -> do @@ -360,7 +360,7 @@ inlinetoplevel expr@(Var f) | not $ isUserDefined f = do -- Leave all other expressions unchanged -inlinetoplevel expr = return expr +inlinetoplevel c expr = return expr inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel) -- | Does the given binder need to be inlined? If so, return the body to @@ -393,14 +393,14 @@ needsInline f = do -------------------------------- -- Inline all top level dictionaries, so we can use them to resolve -- class methods based on the dictionary passed. -inlinedict expr@(Var f) | Id.isDictId f = do +inlinedict c expr@(Var f) | Id.isDictId f = do body_maybe <- Trans.lift $ getGlobalBind f case body_maybe of Nothing -> return expr Just body -> change body -- Leave all other expressions unchanged -inlinedict expr = return expr +inlinedict c expr = return expr inlinedicttop = everywhere ("inlinedict", inlinedict) -------------------------------- @@ -438,7 +438,7 @@ inlinedicttop = everywhere ("inlinedict", inlinedict) -- using $con2tag functions to translate a datacon to an int and compare -- that with GHC.Prim.==# . Better to avoid that for now. classopresolution, classopresolutiontop :: Transform -classopresolution expr@(App (App (Var sel) ty) dict) | not is_builtin = +classopresolution c expr@(App (App (Var sel) ty) dict) | not is_builtin = case Id.isClassOpId_maybe sel of -- Not a class op selector Nothing -> return expr @@ -463,7 +463,7 @@ classopresolution expr@(App (App (Var sel) ty) dict) | not is_builtin = is_builtin = elem (Name.getOccString sel) builtinIds -- Leave all other expressions unchanged -classopresolution expr = return expr +classopresolution c expr = return expr -- Perform this transform everywhere classopresolutiontop = everywhere ("classopresolution", classopresolution) @@ -472,12 +472,12 @@ classopresolutiontop = everywhere ("classopresolution", classopresolution) -------------------------------- scrutsimpl,scrutsimpltop :: Transform -- Don't touch scrutinees that are already simple -scrutsimpl expr@(Case (Var _) _ _ _) = return expr +scrutsimpl c expr@(Case (Var _) _ _ _) = return expr -- Replace all other cases with a let that binds the scrutinee and a new -- simple scrutinee, but only when the scrutinee is representable (to prevent -- loops with inlinenonrep, though I don't think a non-representable scrutinee -- will be supported anyway...) -scrutsimpl expr@(Case scrut b ty alts) = do +scrutsimpl c expr@(Case scrut b ty alts) = do repr <- isRepr scrut if repr then do @@ -486,7 +486,7 @@ scrutsimpl expr@(Case scrut b ty alts) = do else return expr -- Leave all other expressions unchanged -scrutsimpl expr = return expr +scrutsimpl c expr = return expr -- Perform this transform everywhere scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl) @@ -501,18 +501,18 @@ scrutsimpltop = everywhere ("scrutsimpl", scrutsimpl) scrutbndrremove, scrutbndrremovetop :: Transform -- If the scrutinee is already simple, and the bndr is not wild yet, replace -- all occurences of the binder with the scrutinee variable. -scrutbndrremove (Case (Var scrut) bndr ty alts) | bndr_used = do +scrutbndrremove c (Case (Var scrut) bndr ty alts) | bndr_used = do alts' <- mapM subs_bndr alts change $ Case (Var scrut) wild ty alts' where is_used (_, _, expr) = expr_uses_binders [bndr] expr bndr_used = or $ map is_used alts subs_bndr (con, bndrs, expr) = do - expr' <- substitute bndr (Var scrut) expr + expr' <- substitute bndr (Var scrut) c expr return (con, bndrs, expr') wild = MkCore.mkWildBinder (Id.idType bndr) -- Leave all other expressions unchanged -scrutbndrremove expr = return expr +scrutbndrremove c expr = return expr scrutbndrremovetop = everywhere ("scrutbndrremove", scrutbndrremove) -------------------------------- @@ -522,14 +522,14 @@ casesimpl, casesimpltop :: Transform -- This is already a selector case (or, if x does not appear in bndrs, a very -- simple case statement that will be removed by caseremove below). Just leave -- it be. -casesimpl expr@(Case scrut b ty [(con, bndrs, Var x)]) = return expr +casesimpl c expr@(Case scrut b ty [(con, bndrs, Var x)]) = return expr -- Make sure that all case alternatives have only wild binders and simple -- expressions. -- This is done by creating a new let binding for each non-wild binder, which -- is bound to a new simple selector case statement and for each complex -- expression. We do this only for representable types, to prevent loops with -- inlinenonrep. -casesimpl expr@(Case scrut bndr ty alts) | not bndr_used = do +casesimpl c expr@(Case scrut bndr ty alts) | not bndr_used = do (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts let bindings = concat bindingss -- Replace the case with a let with bindings and a case @@ -611,7 +611,7 @@ casesimpl expr@(Case scrut bndr ty alts) | not bndr_used = do -- Don't simplify anything else return (Nothing, expr) -- Leave all other expressions unchanged -casesimpl expr = return expr +casesimpl c expr = return expr -- Perform this transform everywhere casesimpltop = everywhere ("casesimpl", casesimpl) @@ -622,11 +622,11 @@ casesimpltop = everywhere ("casesimpl", casesimpl) -- binders. caseremove, caseremovetop :: Transform -- Replace a useless case by the value of its single alternative -caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr +caseremove c (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr -- Find if any of the binders are used by expr where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` b:bndrs))) expr -- Leave all other expressions unchanged -caseremove expr = return expr +caseremove c expr = return expr -- Perform this transform everywhere caseremovetop = everywhere ("caseremove", caseremove) @@ -637,7 +637,7 @@ caseremovetop = everywhere ("caseremove", caseremove) appsimpl, appsimpltop :: Transform -- Simplify all representable arguments. Do this by introducing a new Let -- that binds the argument and passing the new binder in the application. -appsimpl expr@(App f arg) = do +appsimpl c expr@(App f arg) = do -- Check runtime representability repr <- isRepr arg local_var <- Trans.lift $ is_local_var arg @@ -648,7 +648,7 @@ appsimpl expr@(App f arg) = do else -- Leave non-representable arguments unchanged return expr -- Leave all other expressions unchanged -appsimpl expr = return expr +appsimpl c expr = return expr -- Perform this transform everywhere appsimpltop = everywhere ("appsimpl", appsimpl) @@ -662,7 +662,7 @@ argprop, argproptop :: Transform -- Transform any application of a named function (i.e., skip applications of -- lambda's). Also skip applications that have arguments with free type -- variables, since we can't inline those. -argprop expr@(App _ _) | is_var fexpr = do +argprop c expr@(App _ _) | is_var fexpr = do -- Find the body of the function called body_maybe <- Trans.lift $ getGlobalBind f case body_maybe of @@ -744,7 +744,7 @@ argprop expr@(App _ _) | is_var fexpr = do -- to a new id and just pass that new id to the old function body. return ([arg], [id], mkReferenceTo id) -- Leave all other expressions unchanged -argprop expr = return expr +argprop c expr = return expr -- Perform this transform everywhere argproptop = everywhere ("argprop", argprop) @@ -757,7 +757,7 @@ argproptop = everywhere ("argprop", argprop) -- apply map to a lambda expression This will not conflict with inlinenonrep, -- since that only inlines local let bindings, not top level bindings. funextract, funextracttop :: Transform -funextract expr@(App _ _) | is_var fexpr = do +funextract c expr@(App _ _) | is_var fexpr = do body_maybe <- Trans.lift $ getGlobalBind f case body_maybe of -- We don't have a function body for f, so we can perform this transform. @@ -795,7 +795,7 @@ funextract expr@(App _ _) | is_var fexpr = do doarg arg = return arg -- Leave all other expressions unchanged -funextract expr = return expr +funextract c expr = return expr -- Perform this transform everywhere funextracttop = everywhere ("funextract", funextract) @@ -805,9 +805,9 @@ funextracttop = everywhere ("funextract", funextract) -- solution, we should probably integrate this pass with lambdasimpl and -- letsimpl instead. -------------------------------- -simplrestop expr@(Lam _ _) = return expr -simplrestop expr@(Let _ _) = return expr -simplrestop expr = do +simplrestop c expr@(Lam _ _) = return expr +simplrestop c expr@(Let _ _) = return expr +simplrestop c expr = do local_var <- Trans.lift $ is_local_var expr -- Don't extract values that are not representable, to prevent loops with -- inlinenonrep diff --git "a/c\316\273ash/CLasH/Normalize/NormalizeTools.hs" "b/c\316\273ash/CLasH/Normalize/NormalizeTools.hs" index b9f4544..b1ca369 100644 --- "a/c\316\273ash/CLasH/Normalize/NormalizeTools.hs" +++ "b/c\316\273ash/CLasH/Normalize/NormalizeTools.hs" @@ -37,19 +37,18 @@ everywhere trans = applyboth (subeverywhere (everywhere trans)) trans -- Apply the first transformation, followed by the second transformation, and -- keep applying both for as long as expression still changes. applyboth :: Transform -> (String, Transform) -> Transform -applyboth first (name, second) expr = do +applyboth first (name, second) context expr = do -- Apply the first - expr' <- first expr + expr' <- first context expr -- Apply the second - (expr'', changed) <- Writer.listen $ second expr' + (expr'', changed) <- Writer.listen $ second context expr' if Monoid.getAny $ -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") changed then -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n" -- ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $ - applyboth first (name, second) - expr'' + applyboth first (name, second) context expr'' else -- trace ("No changes") $ return expr'' @@ -57,49 +56,49 @@ applyboth first (name, second) expr = do -- Apply the given transformation to all direct subexpressions (only), not the -- expression itself. subeverywhere :: Transform -> Transform -subeverywhere trans (App a b) = do - a' <- trans a - b' <- trans b +subeverywhere trans c (App a b) = do + a' <- trans (Other:c) a + b' <- trans (Other:c) b return $ App a' b' -subeverywhere trans (Let (NonRec b bexpr) expr) = do - bexpr' <- trans bexpr - expr' <- trans expr +subeverywhere trans c (Let (NonRec b bexpr) expr) = do + bexpr' <- trans (Other:c) bexpr + expr' <- trans (Other:c) expr return $ Let (NonRec b bexpr') expr' -subeverywhere trans (Let (Rec binds) expr) = do - expr' <- trans expr +subeverywhere trans c (Let (Rec binds) expr) = do + expr' <- trans (Other:c) expr binds' <- mapM transbind binds return $ Let (Rec binds') expr' where transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr) transbind (b, e) = do - e' <- trans e + e' <- trans (Other:c) e return (b, e') -subeverywhere trans (Lam x expr) = do - expr' <- trans expr +subeverywhere trans c (Lam x expr) = do + expr' <- trans (Other:c) expr return $ Lam x expr' -subeverywhere trans (Case scrut b t alts) = do - scrut' <- trans scrut +subeverywhere trans c (Case scrut b t alts) = do + scrut' <- trans (Other:c) scrut alts' <- mapM transalt alts return $ Case scrut' b t alts' where transalt :: CoreAlt -> TransformMonad CoreAlt transalt (con, binders, expr) = do - expr' <- trans expr + expr' <- trans (Other:c) expr return (con, binders, expr') -subeverywhere trans (Var x) = return $ Var x -subeverywhere trans (Lit x) = return $ Lit x -subeverywhere trans (Type x) = return $ Type x +subeverywhere trans c (Var x) = return $ Var x +subeverywhere trans c (Lit x) = return $ Lit x +subeverywhere trans c (Type x) = return $ Type x -subeverywhere trans (Cast expr ty) = do - expr' <- trans expr +subeverywhere trans c (Cast expr ty) = do + expr' <- trans (Other:c) expr return $ Cast expr' ty -subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr +subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr -- Apply the given transformation to all expressions, except for direct -- arguments of an application @@ -110,34 +109,34 @@ notappargs trans = applyboth (subnotappargs trans) trans -- (but not the expression itself), except for direct arguments of an -- application subnotappargs :: (String, Transform) -> Transform -subnotappargs trans (App a b) = do - a' <- subnotappargs trans a - b' <- subnotappargs trans b +subnotappargs trans c (App a b) = do + a' <- subnotappargs trans (Other:c) a + b' <- subnotappargs trans (Other:c) b return $ App a' b' -- Let subeverywhere handle all other expressions -subnotappargs trans expr = subeverywhere (notappargs trans) expr +subnotappargs trans c expr = subeverywhere (notappargs trans) c expr -- Runs each of the transforms repeatedly inside the State monad. dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr dotransforms transs expr = do - (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs + (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> trans [] e) expr transs if Monoid.getAny changed then dotransforms transs expr' else return expr' -- Inline all let bindings that satisfy the given condition inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform -inlinebind condition expr@(Let (NonRec bndr expr') res) = do +inlinebind condition context expr@(Let (NonRec bndr expr') res) = do applies <- condition (bndr, expr') if applies then do -- Substitute the binding in res and return that - res' <- substitute_clone bndr expr' res + res' <- substitute_clone bndr expr' context res change res' else -- Don't change this let return expr -- Leave all other expressions unchanged -inlinebind _ expr = return expr +inlinebind _ context expr = return expr -- Sets the changed flag in the TransformMonad, to signify that some -- transform has changed the result @@ -161,7 +160,7 @@ changeif False val = return val -- Does not set the changed flag. substitute :: CoreBndr -> CoreExpr -> Transform -- Use CoreSubst to subst a type var in an expression -substitute find repl expr = do +substitute find repl context expr = do let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl return $ CoreSubst.substExpr subst expr @@ -170,12 +169,12 @@ substitute find repl expr = do -- expression are cloned before the replacement, to guarantee uniqueness. substitute_clone :: CoreBndr -> CoreExpr -> Transform -- If we see the var to find, replace it by a uniqued version of repl -substitute_clone find repl (Var var) | find == var = do +substitute_clone find repl context (Var var) | find == var = do repl' <- Trans.lift $ CoreTools.genUniques repl change repl' -- For all other expressions, just look in subexpressions -substitute_clone find repl expr = subeverywhere (substitute_clone find repl) expr +substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr -- Is the given expression representable at runtime, based on the type? isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool diff --git "a/c\316\273ash/CLasH/Normalize/NormalizeTypes.hs" "b/c\316\273ash/CLasH/Normalize/NormalizeTypes.hs" index 3affc87..a7de6dc 100644 --- "a/c\316\273ash/CLasH/Normalize/NormalizeTypes.hs" +++ "b/c\316\273ash/CLasH/Normalize/NormalizeTypes.hs" @@ -14,5 +14,8 @@ import CLasH.Translator.TranslatorTypes -- over a single expression and track if the expression was changed. type TransformMonad = Writer.WriterT Monoid.Any TranslatorSession +-- | In what context does a core expression occur? +data CoreContext = Other -- ^ Another context + -- | Transforms a CoreExpr and keeps track if it has changed. -type Transform = CoreSyn.CoreExpr -> TransformMonad CoreSyn.CoreExpr +type Transform = [CoreContext] -> CoreSyn.CoreExpr -> TransformMonad CoreSyn.CoreExpr -- 2.30.2