Make inlinebind work for non-recursive lets.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
index e1b8727086011bcc1a85094ca8059fb5bcc2e784..116c84742f45f7953b7fcea588820c8d7e305da3 100644 (file)
@@ -3,6 +3,7 @@
 -- This module provides functions for program transformations.
 --
 module CLasH.Normalize.NormalizeTools where
+
 -- Standard modules
 import Debug.Trace
 import qualified List
@@ -19,79 +20,17 @@ import Data.Accessor.MonadState as MonadState
 
 -- GHC API
 import CoreSyn
-import qualified UniqSupply
-import qualified Unique
-import qualified OccName
-import qualified Name
-import qualified Var
-import qualified SrcLoc
-import qualified Type
-import qualified IdInfo
-import qualified CoreUtils
 import qualified CoreSubst
-import qualified VarSet
-import qualified HscTypes
+import qualified CoreUtils
 import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
+import CLasH.Translator.TranslatorTypes
 import CLasH.Utils.Pretty
 import CLasH.VHDL.VHDLTypes
 import qualified CLasH.VHDL.VHDLTools as VHDLTools
 
--- Create a new internal var with the given name and type. A Unique is
--- appended to the given name, to ensure uniqueness (not strictly neccesary,
--- since the Unique is also stored in the name, but this ensures variable
--- names are unique in the output).
-mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
-mkInternalVar str ty = do
-  uniq <- mkUnique
-  let occname = OccName.mkVarOcc (str ++ show uniq)
-  let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
-  return $ Var.mkLocalVar IdInfo.VanillaId name ty IdInfo.vanillaIdInfo
-
--- Create a new type variable with the given name and kind. A Unique is
--- appended to the given name, to ensure uniqueness (not strictly neccesary,
--- since the Unique is also stored in the name, but this ensures variable
--- names are unique in the output).
-mkTypeVar :: String -> Type.Kind -> TransformMonad Var.Var
-mkTypeVar str kind = do
-  uniq <- mkUnique
-  let occname = OccName.mkVarOcc (str ++ show uniq)
-  let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
-  return $ Var.mkTyVar name kind
-
--- Creates a binder for the given expression with the given name. This
--- works for both value and type level expressions, so it can return a Var or
--- TyVar (which is just an alias for Var).
-mkBinderFor :: CoreExpr -> String -> TransformMonad Var.Var
-mkBinderFor (Type ty) string = mkTypeVar string (Type.typeKind ty)
-mkBinderFor expr string = mkInternalVar string (CoreUtils.exprType expr)
-
--- Creates a reference to the given variable. This works for both a normal
--- variable as well as a type variable
-mkReferenceTo :: Var.Var -> CoreExpr
-mkReferenceTo var | Var.isTyVar var = (Type $ Type.mkTyVarTy var)
-                  | otherwise       = (Var var)
-
-cloneVar :: Var.Var -> TransformMonad Var.Var
-cloneVar v = do
-  uniq <- mkUnique
-  -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
-  -- contains, but vannillaIdInfo is always correct, since it means "no info").
-  return $ Var.lazySetIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
-
--- Creates a new function with the same name as the given binder (but with a
--- new unique) and with the given function body. Returns the new binder for
--- this function.
-mkFunction :: CoreBndr -> CoreExpr -> TransformMonad CoreBndr
-mkFunction bndr body = do
-  let ty = CoreUtils.exprType body
-  id <- cloneVar bndr
-  let newid = Var.setVarType id ty
-  Trans.lift $ addGlobalBind newid body
-  return newid
-
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
 everywhere :: (String, Transform) -> Transform
@@ -182,29 +121,22 @@ subnotappargs trans (App a b) = do
 subnotappargs trans expr = subeverywhere (notappargs trans) expr
 
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
+dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
 dotransforms transs expr = do
   (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition expr@(Let (Rec binds) res) = do
-    -- Find all bindings that adhere to the condition
-    res_eithers <- mapM docond binds
-    case Either.partitionEithers res_eithers of
-      -- No replaces? No change
-      ([], _) -> return expr
-      (replace, others) -> do
-        -- Substitute the to be replaced binders with their expression
-        let newexpr = substitute replace (Let (Rec others) res)
-        change newexpr
-  where 
-    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
-    docond b = do
-      res <- condition b
-      return $ case res of True -> Left b; False -> Right b
-
+inlinebind condition expr@(Let (NonRec bndr expr') res) = do
+    applies <- condition (bndr, expr')
+    if applies
+      then
+        -- Substitute the binding in res and return that
+        change $ substitute [(bndr, expr')] res
+      else
+        -- Don't change this let
+        return expr
 -- Leave all other expressions unchanged
 inlinebind _ expr = return expr
 
@@ -219,13 +151,11 @@ change val = do
   setChanged
   return val
 
--- Create a new Unique
-mkUnique :: TransformMonad Unique.Unique
-mkUnique = Trans.lift $ do
-    us <- getA tsUniqSupply 
-    let (us', us'') = UniqSupply.splitUniqSupply us
-    putA tsUniqSupply us'
-    return $ UniqSupply.uniqFromSupply us''
+-- Returns the given value and sets the changed flag if the bool given is
+-- True. Note that this will not unset the changed flag if the bool is False.
+changeif :: Bool -> a -> TransformMonad a
+changeif True val = change val
+changeif False val = return val
 
 -- Replace each of the binders given with the coresponding expressions in the
 -- given expression.
@@ -246,20 +176,12 @@ substitute ((b, e):subss) expr = substitute subss' expr'
     -- substitutions
     subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
 
--- Run a given TransformSession. Used mostly to setup the right calls and
--- an initial state.
-runTransformSession :: HscTypes.HscEnv -> UniqSupply.UniqSupply -> TransformSession a -> a
-runTransformSession env uniqSupply session = State.evalState session emptyTransformState
-  where
-    emptyTypeState = TypeState Map.empty [] Map.empty Map.empty env
-    emptyTransformState = TransformState uniqSupply Map.empty VarSet.emptyVarSet emptyTypeState
-
 -- Is the given expression representable at runtime, based on the type?
 isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
 isRepr (Type ty) = return False
 isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
 
-is_local_var :: CoreSyn.CoreExpr -> TransformSession Bool
+is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
 is_local_var (CoreSyn.Var v) = do
   bndrs <- getGlobalBinders
   return $ not $ v `elem` bndrs