---
--- This module provides functions for program transformations.
---
-module CLasH.Normalize.NormalizeTools where
-
--- Standard modules
-import qualified Data.Monoid as Monoid
-import qualified Data.Either as Either
-import qualified Control.Monad as Monad
-import qualified Control.Monad.Trans.Writer as Writer
-import qualified Control.Monad.Trans.Class as Trans
-import qualified Data.Accessor.Monad.Trans.State as MonadState
-
--- GHC API
-import CoreSyn
-import qualified Name
-import qualified Id
-import qualified CoreSubst
-import qualified Type
-import qualified CoreUtils
-import Outputable ( showSDoc, ppr, nest )
-
--- Local imports
-import CLasH.Normalize.NormalizeTypes
-import CLasH.Translator.TranslatorTypes
-import CLasH.VHDL.Constants (builtinIds)
-import CLasH.Utils
-import qualified CLasH.Utils.Core.CoreTools as CoreTools
-import qualified CLasH.VHDL.VHDLTools as VHDLTools
-
--- Apply the given transformation to all expressions in the given expression,
--- including the expression itself.
-everywhere :: Transform -> Transform
-everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
-
-data NormDbgLevel =
- NormDbgNone -- ^ No debugging
- | NormDbgFinal -- ^ Print functions before / after normalization
- | NormDbgApplied -- ^ Print expressions before / after applying transformations
- | NormDbgAll -- ^ Print expressions when a transformation does not apply
- deriving (Eq, Ord)
-normalize_debug = NormDbgFinal
-
--- Applies a transform, optionally showing some debug output.
-apply :: (String, Transform) -> Transform
-apply (name, trans) ctx expr = do
- -- Apply the transformation and find out if it changed anything
- (expr', any_changed) <- Writer.listen $ trans ctx expr
- let changed = Monoid.getAny any_changed
- -- If it changed, increase the transformation counter
- Monad.when changed $ Trans.lift (MonadState.modify tsTransformCounter (+1))
- -- Prepare some debug strings
- let before = showSDoc (nest 4 $ ppr expr) ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr) ++ "\n"
- let context = "Context: " ++ show ctx ++ "\n"
- let after = showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
- traceIf (normalize_debug >= NormDbgApplied && changed) ("Changes when applying transform " ++ name ++ " to:\n" ++ before ++ context ++ "Result:\n" ++ after) $
- traceIf (normalize_debug >= NormDbgAll && not changed) ("No changes when applying transform " ++ name ++ " to:\n" ++ before ++ context) $
- return expr'
-
--- Apply the first transformation, followed by the second transformation, and
--- keep applying both for as long as expression still changes.
-applyboth :: Transform -> Transform -> Transform
-applyboth first second context expr = do
- -- Apply the first
- expr' <- first context expr
- -- Apply the second
- (expr'', changed) <- Writer.listen $ second context expr'
- if Monoid.getAny $ changed
- then
- applyboth first second context expr''
- else
- return expr''
-
--- Apply the given transformation to all direct subexpressions (only), not the
--- expression itself.
-subeverywhere :: Transform -> Transform
-subeverywhere trans c (App a b) = do
- a' <- trans (AppFirst:c) a
- b' <- trans (AppSecond:c) b
- return $ App a' b'
-
-subeverywhere trans c (Let (NonRec b bexpr) expr) = do
- bexpr' <- trans (LetBinding:c) bexpr
- expr' <- trans (LetBody:c) expr
- return $ Let (NonRec b bexpr') expr'
-
-subeverywhere trans c (Let (Rec binds) expr) = do
- expr' <- trans (LetBody:c) expr
- binds' <- mapM transbind binds
- return $ Let (Rec binds') expr'
- where
- transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
- transbind (b, e) = do
- e' <- trans (LetBinding:c) e
- return (b, e')
-
-subeverywhere trans c (Lam x expr) = do
- expr' <- trans (LambdaBody:c) expr
- return $ Lam x expr'
-
-subeverywhere trans c (Case scrut b t alts) = do
- scrut' <- trans (Other:c) scrut
- alts' <- mapM transalt alts
- return $ Case scrut' b t alts'
- where
- transalt :: CoreAlt -> TransformMonad CoreAlt
- transalt (con, binders, expr) = do
- expr' <- trans (Other:c) expr
- return (con, binders, expr')
-
-subeverywhere trans c (Var x) = return $ Var x
-subeverywhere trans c (Lit x) = return $ Lit x
-subeverywhere trans c (Type x) = return $ Type x
-
-subeverywhere trans c (Cast expr ty) = do
- expr' <- trans (Other:c) expr
- return $ Cast expr' ty
-
-subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
-
--- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
-dotransforms transs expr = do
- (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere (apply trans) [] e) expr transs
- if Monoid.getAny changed then dotransforms transs expr' else return expr'
-
--- Inline all let bindings that satisfy the given condition
-inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition context expr@(Let (Rec binds) res) = do
- -- Find all bindings that adhere to the condition
- res_eithers <- mapM docond binds
- case Either.partitionEithers res_eithers of
- -- No replaces? No change
- ([], _) -> return expr
- (replace, others) -> do
- -- Substitute the to be replaced binders with their expression
- newexpr <- do_substitute replace (Let (Rec others) res)
- change newexpr
- where
- -- Apply the condition to a let binding and return an Either
- -- depending on whether it needs to be inlined or not.
- docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
- docond b = do
- res <- condition b
- return $ case res of True -> Left b; False -> Right b
-
- -- Apply the given list of substitutions to the the given expression
- do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
- do_substitute [] expr = return expr
- do_substitute ((bndr, val):reps) expr = do
- -- Perform this substitution in the expression
- expr' <- substitute_clone bndr val context expr
- -- And in the substitution values we will be using next
- reps' <- mapM (subs_bind bndr val) reps
- -- And then perform the remaining substitutions
- do_substitute reps' expr'
-
- -- Replace the given binder with the given expression in the
- -- expression oft the given let binding
- subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
- subs_bind bndr expr (b, v) = do
- v' <- substitute_clone bndr expr (LetBinding:context) v
- return (b, v')
-
-
--- Leave all other expressions unchanged
-inlinebind _ context expr = return expr
-
--- Sets the changed flag in the TransformMonad, to signify that some
--- transform has changed the result
-setChanged :: TransformMonad ()
-setChanged = Writer.tell (Monoid.Any True)
-
--- Sets the changed flag and returns the given value.
-change :: a -> TransformMonad a
-change val = do
- setChanged
- return val
-
--- Returns the given value and sets the changed flag if the bool given is
--- True. Note that this will not unset the changed flag if the bool is False.
-changeif :: Bool -> a -> TransformMonad a
-changeif True val = change val
-changeif False val = return val
-
--- | Creates a transformation that substitutes the given binder with the given
--- expression (This can be a type variable, replace by a Type expression).
--- Does not set the changed flag.
-substitute :: CoreBndr -> CoreExpr -> Transform
--- Use CoreSubst to subst a type var in an expression
-substitute find repl context expr = do
- let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
- return $ CoreSubst.substExpr subst expr
-
--- | Creates a transformation that substitutes the given binder with the given
--- expression. This does only work for value expressions! All binders in the
--- expression are cloned before the replacement, to guarantee uniqueness.
-substitute_clone :: CoreBndr -> CoreExpr -> Transform
--- If we see the var to find, replace it by a uniqued version of repl
-substitute_clone find repl context (Var var) | find == var = do
- repl' <- Trans.lift $ CoreTools.genUniques repl
- change repl'
-
--- For all other expressions, just look in subexpressions
-substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
-
--- Is the given expression representable at runtime, based on the type?
-isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
-isRepr tything = Trans.lift (isRepr' tything)
-
-isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
-isRepr' tything = case CoreTools.getType tything of
- Nothing -> return False
- Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty
-
-is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
-is_local_var (CoreSyn.Var v) = do
- bndrs <- getGlobalBinders
- return $ v `notElem` bndrs
-is_local_var _ = return False
-
--- Is the given binder defined by the user?
-isUserDefined :: CoreSyn.CoreBndr -> Bool
--- System names are certain to not be user defined
-isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
--- Builtin functions are usually not user-defined either (and would
--- break currently if they are...)
-isUserDefined bndr = str `notElem` builtinIds
- where
- str = Name.getOccString bndr
-
--- | Is the given binder normalizable? This means that its type signature can be
--- represented in hardware, which should (?) guarantee that it can be made
--- into hardware. This checks whether all the arguments and (optionally)
--- the return value are
--- representable.
-isNormalizeable ::
- Bool -- ^ Allow the result to be unrepresentable?
- -> CoreBndr -- ^ The binder to check
- -> TranslatorSession Bool -- ^ Is it normalizeable?
-isNormalizeable result_nonrep bndr = do
- let ty = Id.idType bndr
- let (arg_tys, res_ty) = Type.splitFunTys ty
- let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys)
- andM $ mapM isRepr' check_tys