-- top level function "normalize", and defines the actual transformation passes that
-- are performed.
--
-module CLasH.Normalize (normalizeModule) where
+module CLasH.Normalize (getNormalized, normalizeExpr) where
-- Standard modules
import Debug.Trace
-- Local imports
import CLasH.Normalize.NormalizeTypes
+import CLasH.Translator.TranslatorTypes
import CLasH.Normalize.NormalizeTools
import CLasH.VHDL.VHDLTypes
+import qualified CLasH.Utils as Utils
import CLasH.Utils.Core.CoreTools
+import CLasH.Utils.Core.BinderTools
import CLasH.Utils.Pretty
--------------------------------
eta, etatop :: Transform
eta expr | is_fun expr && not (is_lam expr) = do
let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
- id <- mkInternalVar "param" arg_ty
+ id <- Trans.lift $ mkInternalVar "param" arg_ty
change (Lam id (App expr (Var id)))
-- Leave all other expressions unchanged
eta e = return e
then do
-- If the result is not a local var already (to prevent loops with
-- ourselves), extract it.
- id <- mkInternalVar "foo" (CoreUtils.exprType res)
+ id <- Trans.lift $ mkInternalVar "foo" (CoreUtils.exprType res)
let bind = (id, res)
change $ Let (Rec (bind:binds)) (Var id)
else
letremovetop :: Transform
letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
+--------------------------------
+-- Unused let binding removal
+--------------------------------
+letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (Rec binds) res) = do
+ -- Filter out all unused binds.
+ let binds' = filter dobind binds
+ -- Only set the changed flag if binds got removed
+ changeif (length binds' /= length binds) (Let (Rec binds') res)
+ where
+ bound_exprs = map snd binds
+ -- For each bind check if the bind is used by res or any of the bound
+ -- expressions
+ dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs)
+-- Leave all other expressions unchanged
+letremoveunused expr = return expr
+letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
+
--------------------------------
-- Function inlining
--------------------------------
repr <- isRepr scrut
if repr
then do
- id <- mkInternalVar "scrut" (CoreUtils.exprType scrut)
+ id <- Trans.lift $ mkInternalVar "scrut" (CoreUtils.exprType scrut)
change $ Let (Rec [(id, scrut)]) (Case (Var id) b ty alts)
else
return expr
-- Create on new binder that will actually capture a value in this
-- case statement, and return it.
let bty = (Id.idType b)
- id <- mkInternalVar "sel" bty
+ id <- Trans.lift $ mkInternalVar "sel" bty
let binders = take i wildbndrs ++ [id] ++ drop (i+1) wildbndrs
let caseexpr = Case scrut b bty [(con, binders, Var id)]
return (wildbndrs!!i, Just (b, caseexpr))
-- prevent loops with inlinenonrep).
if (not uses_bndrs) && (not local_var) && repr
then do
- id <- mkInternalVar "caseval" (CoreUtils.exprType expr)
+ id <- Trans.lift $ mkInternalVar "caseval" (CoreUtils.exprType expr)
-- We don't flag a change here, since casevalsimpl will do that above
-- based on Just we return here.
return $ (Just (id, expr), Var id)
local_var <- Trans.lift $ is_local_var arg
if repr && not local_var
then do -- Extract representable arguments
- id <- mkInternalVar "arg" (CoreUtils.exprType arg)
+ id <- Trans.lift $ mkInternalVar "arg" (CoreUtils.exprType arg)
change $ Let (Rec [(id, arg)]) (App f (Var id))
else -- Leave non-representable arguments unchanged
return expr
-- the old body applied to some arguments.
let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
-- Create a new function with the same name but a new body
- newf <- mkFunction f newbody
+ newf <- Trans.lift $ mkFunction f newbody
-- Replace the original application with one of the new function to the
-- new arguments.
change $ MkCore.mkCoreApps (Var newf) newargs
-- Representable types will not be propagated, and arguments with free
-- type variables will be propagated later.
-- TODO: preserve original naming?
- id <- mkBinderFor arg "param"
+ id <- Trans.lift $ mkBinderFor arg "param"
-- Just pass the original argument to the new function, which binds it
-- to a new id and just pass that new id to the old function body.
return ([arg], [id], mkReferenceTo id)
-- by the argument expression.
let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
let body = MkCore.mkCoreLams free_vars arg
- id <- mkBinderFor body "fun"
+ id <- Trans.lift $ mkBinderFor body "fun"
Trans.lift $ addGlobalBind id body
-- Replace the argument with a reference to the new function, applied to
-- all vars it uses.
-- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop]
-
--- Turns the given bind into VHDL
-normalizeModule ::
- HscTypes.HscEnv
- -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
- -> [(CoreBndr, CoreExpr)] -- ^ All bindings we know (i.e., in the current module)
- -> [CoreExpr]
- -> [CoreBndr] -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
- -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
- -> ([(CoreBndr, CoreExpr)], [(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
-
-normalizeModule env uniqsupply bindings testexprs generate_for statefuls = runTransformSession env uniqsupply $ do
- testbinds <- mapM (\x -> do { v <- mkBinderFor' x "test" ; return (v,x) } ) testexprs
- let testbinders = (map fst testbinds)
- -- Put all the bindings in this module in the tsBindings map
- putA tsBindings (Map.fromList (bindings ++ testbinds))
- -- (Recursively) normalize each of the requested bindings
- mapM normalizeBind (generate_for ++ testbinders)
- -- Get all initial bindings and the ones we produced
- bindings_map <- getA tsBindings
- let bindings = Map.assocs bindings_map
- normalized_binders' <- getA tsNormalized
- let normalized_binders = VarSet.delVarSetList normalized_binders' testbinders
- let ret_testbinds = zip testbinders (Maybe.catMaybes $ map (\x -> lookup x bindings) testbinders)
- let ret_binds = filter ((`VarSet.elemVarSet` normalized_binders) . fst) bindings
- typestate <- getA tsType
- -- But return only the normalized bindings
- return $ (ret_binds, ret_testbinds, typestate)
-
-normalizeBind :: CoreBndr -> TransformSession ()
-normalizeBind bndr =
- -- Don't normalize global variables, these should be either builtin
- -- functions or data constructors.
- Monad.when (Var.isLocalId bndr) $ do
- -- Skip binders that have a polymorphic type, since it's impossible to
- -- create polymorphic hardware.
- if is_poly (Var bndr)
- then
- -- This should really only happen at the top level... TODO: Give
- -- a different error if this happens down in the recursion.
- error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
- else do
- normalized_funcs <- getA tsNormalized
- -- See if this function was normalized already
- if VarSet.elemVarSet bndr normalized_funcs
- then
- -- Yup, don't do it again
- return ()
- else do
- -- Nope, note that it has been and do it.
- modA tsNormalized (flip VarSet.extendVarSet bndr)
- expr_maybe <- getGlobalBind bndr
- case expr_maybe of
- Just expr -> do
- -- Introduce an empty Let at the top level, so there will always be
- -- a let in the expression (none of the transformations will remove
- -- the last let).
- let expr' = Let (Rec []) expr
- -- Normalize this expression
- trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
- expr' <- dotransforms transforms expr'
- trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
- -- And store the normalized version in the session
- modA tsBindings (Map.insert bndr expr')
- -- Find all vars used with a function type. All of these should be global
- -- binders (i.e., functions used), since any local binders with a function
- -- type should have been inlined already.
- bndrs <- getGlobalBinders
- let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
- let used_funcs = VarSet.varSetElems used_funcs_set
- -- Process each of the used functions recursively
- mapM normalizeBind used_funcs
- return ()
- -- We don't have a value for this binder. This really shouldn't
- -- happen for local id's...
- Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
+
+-- | Returns the normalized version of the given function.
+getNormalized ::
+ CoreBndr -- ^ The function to get
+ -> TranslatorSession CoreExpr -- The normalized function body
+
+getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
+ if is_poly (Var bndr)
+ then
+ -- This should really only happen at the top level... TODO: Give
+ -- a different error if this happens down in the recursion.
+ error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+ else do
+ expr <- getBinding bndr
+ normalizeExpr (show bndr) expr
+
+-- | Normalize an expression
+normalizeExpr ::
+ String -- ^ What are we normalizing? For debug output only.
+ -> CoreSyn.CoreExpr -- ^ The expression to normalize
+ -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
+
+normalizeExpr what expr = do
+ -- Introduce an empty Let at the top level, so there will always be
+ -- a let in the expression (none of the transformations will remove
+ -- the last let).
+ let expr' = Let (Rec []) expr
+ -- Normalize this expression
+ trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+ expr'' <- dotransforms transforms expr'
+ trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+ return expr''
+
+-- | Get the value that is bound to the given binder at top level. Fails when
+-- there is no such binding.
+getBinding ::
+ CoreBndr -- ^ The binder to get the expression for
+ -> TranslatorSession CoreExpr -- ^ The value bound to the binder
+
+getBinding bndr = Utils.makeCached bndr tsBindings $ do
+ -- If the binding isn't in the "cache" (bindings map), then we can't create
+ -- it out of thin air, so return an error.
+ error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr