Move the application of "everywhere" to dotransforms.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
index e1b8727086011bcc1a85094ca8059fb5bcc2e784..d9d4bd34e97df38f71d247d5d9ecd5707fef4665 100644 (file)
@@ -3,95 +3,33 @@
 -- This module provides functions for program transformations.
 --
 module CLasH.Normalize.NormalizeTools where
+
 -- Standard modules
-import Debug.Trace
-import qualified List
 import qualified Data.Monoid as Monoid
 import qualified Data.Either as Either
-import qualified Control.Arrow as Arrow
 import qualified Control.Monad as Monad
-import qualified Control.Monad.Trans.State as State
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
-import qualified Data.Map as Map
-import Data.Accessor
-import Data.Accessor.MonadState as MonadState
+import qualified Data.Accessor.Monad.Trans.State as MonadState
+-- import Debug.Trace
 
 -- GHC API
 import CoreSyn
-import qualified UniqSupply
-import qualified Unique
-import qualified OccName
 import qualified Name
-import qualified Var
-import qualified SrcLoc
-import qualified Type
-import qualified IdInfo
-import qualified CoreUtils
+import qualified Id
 import qualified CoreSubst
-import qualified VarSet
-import qualified HscTypes
-import Outputable ( showSDoc, ppr, nest )
+import qualified Type
+-- import qualified CoreUtils
+-- import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
-import CLasH.Utils.Pretty
-import CLasH.VHDL.VHDLTypes
+import CLasH.Translator.TranslatorTypes
+import CLasH.VHDL.Constants (builtinIds)
+import CLasH.Utils
+import qualified CLasH.Utils.Core.CoreTools as CoreTools
 import qualified CLasH.VHDL.VHDLTools as VHDLTools
 
--- Create a new internal var with the given name and type. A Unique is
--- appended to the given name, to ensure uniqueness (not strictly neccesary,
--- since the Unique is also stored in the name, but this ensures variable
--- names are unique in the output).
-mkInternalVar :: String -> Type.Type -> TransformMonad Var.Var
-mkInternalVar str ty = do
-  uniq <- mkUnique
-  let occname = OccName.mkVarOcc (str ++ show uniq)
-  let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
-  return $ Var.mkLocalVar IdInfo.VanillaId name ty IdInfo.vanillaIdInfo
-
--- Create a new type variable with the given name and kind. A Unique is
--- appended to the given name, to ensure uniqueness (not strictly neccesary,
--- since the Unique is also stored in the name, but this ensures variable
--- names are unique in the output).
-mkTypeVar :: String -> Type.Kind -> TransformMonad Var.Var
-mkTypeVar str kind = do
-  uniq <- mkUnique
-  let occname = OccName.mkVarOcc (str ++ show uniq)
-  let name = Name.mkInternalName uniq occname SrcLoc.noSrcSpan
-  return $ Var.mkTyVar name kind
-
--- Creates a binder for the given expression with the given name. This
--- works for both value and type level expressions, so it can return a Var or
--- TyVar (which is just an alias for Var).
-mkBinderFor :: CoreExpr -> String -> TransformMonad Var.Var
-mkBinderFor (Type ty) string = mkTypeVar string (Type.typeKind ty)
-mkBinderFor expr string = mkInternalVar string (CoreUtils.exprType expr)
-
--- Creates a reference to the given variable. This works for both a normal
--- variable as well as a type variable
-mkReferenceTo :: Var.Var -> CoreExpr
-mkReferenceTo var | Var.isTyVar var = (Type $ Type.mkTyVarTy var)
-                  | otherwise       = (Var var)
-
-cloneVar :: Var.Var -> TransformMonad Var.Var
-cloneVar v = do
-  uniq <- mkUnique
-  -- Swap out the unique, and reset the IdInfo (I'm not 100% sure what it
-  -- contains, but vannillaIdInfo is always correct, since it means "no info").
-  return $ Var.lazySetIdInfo (Var.setVarUnique v uniq) IdInfo.vanillaIdInfo
-
--- Creates a new function with the same name as the given binder (but with a
--- new unique) and with the given function body. Returns the new binder for
--- this function.
-mkFunction :: CoreBndr -> CoreExpr -> TransformMonad CoreBndr
-mkFunction bndr body = do
-  let ty = CoreUtils.exprType body
-  id <- cloneVar bndr
-  let newid = Var.setVarType id ty
-  Trans.lift $ addGlobalBind newid body
-  return newid
-
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
 everywhere :: (String, Transform) -> Transform
@@ -100,96 +38,81 @@ everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
 -- Apply the first transformation, followed by the second transformation, and
 -- keep applying both for as long as expression still changes.
 applyboth :: Transform -> (String, Transform) -> Transform
-applyboth first (name, second) expr  = do
+applyboth first (name, second) context expr = do
   -- Apply the first
-  expr' <- first expr
+  expr' <- first context expr
   -- Apply the second
-  (expr'', changed) <- Writer.listen $ second expr'
+  (expr'', changed) <- Writer.listen $ second context expr'
   if Monoid.getAny $
---        trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
+        -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
         changed 
-    then 
---      trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n") $
---      trace ("Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
-      applyboth first (name, second) $
-        expr'' 
+    then
+     -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
+     --        ++ "Context: " ++ show context ++ "\n"
+     --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
+      do
+        Trans.lift $ MonadState.modify tsTransformCounter (+1)
+        applyboth first (name, second) context expr'' 
     else 
---      trace ("No changes") $
+      -- trace ("No changes") $
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
 -- expression itself.
 subeverywhere :: Transform -> Transform
-subeverywhere trans (App a b) = do
-  a' <- trans a
-  b' <- trans b
+subeverywhere trans (App a b) = do
+  a' <- trans (AppFirst:c) a
+  b' <- trans (AppSecond:c) b
   return $ App a' b'
 
-subeverywhere trans (Let (NonRec b bexpr) expr) = do
-  bexpr' <- trans bexpr
-  expr' <- trans expr
+subeverywhere trans (Let (NonRec b bexpr) expr) = do
+  bexpr' <- trans (LetBinding:c) bexpr
+  expr' <- trans (LetBody:c) expr
   return $ Let (NonRec b bexpr') expr'
 
-subeverywhere trans (Let (Rec binds) expr) = do
-  expr' <- trans expr
+subeverywhere trans (Let (Rec binds) expr) = do
+  expr' <- trans (LetBody:c) expr
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
-      e' <- trans e
+      e' <- trans (LetBinding:c) e
       return (b, e')
 
-subeverywhere trans (Lam x expr) = do
-  expr' <- trans expr
+subeverywhere trans (Lam x expr) = do
+  expr' <- trans (LambdaBody:c) expr
   return $ Lam x expr'
 
-subeverywhere trans (Case scrut b t alts) = do
-  scrut' <- trans scrut
+subeverywhere trans (Case scrut b t alts) = do
+  scrut' <- trans (Other:c) scrut
   alts' <- mapM transalt alts
   return $ Case scrut' b t alts'
   where
     transalt :: CoreAlt -> TransformMonad CoreAlt
     transalt (con, binders, expr) = do
-      expr' <- trans expr
+      expr' <- trans (Other:c) expr
       return (con, binders, expr')
 
-subeverywhere trans (Var x) = return $ Var x
-subeverywhere trans (Lit x) = return $ Lit x
-subeverywhere trans (Type x) = return $ Type x
+subeverywhere trans (Var x) = return $ Var x
+subeverywhere trans (Lit x) = return $ Lit x
+subeverywhere trans (Type x) = return $ Type x
 
-subeverywhere trans (Cast expr ty) = do
-  expr' <- trans expr
+subeverywhere trans (Cast expr ty) = do
+  expr' <- trans (Other:c) expr
   return $ Cast expr' ty
 
-subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
-
--- Apply the given transformation to all expressions, except for direct
--- arguments of an application
-notappargs :: (String, Transform) -> Transform
-notappargs trans = applyboth (subnotappargs trans) trans
-
--- Apply the given transformation to all (direct and indirect) subexpressions
--- (but not the expression itself), except for direct arguments of an
--- application
-subnotappargs :: (String, Transform) -> Transform
-subnotappargs trans (App a b) = do
-  a' <- subnotappargs trans a
-  b' <- subnotappargs trans b
-  return $ App a' b'
-
--- Let subeverywhere handle all other expressions
-subnotappargs trans expr = subeverywhere (notappargs trans) expr
+subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
 
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [Transform] -> CoreExpr -> TransformSession CoreExpr
+dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
 dotransforms transs expr = do
-  (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
+  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere trans [] e) expr transs
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition expr@(Let (Rec binds) res) = do
+inlinebind condition context expr@(Let (Rec binds) res) = do
     -- Find all bindings that adhere to the condition
     res_eithers <- mapM docond binds
     case Either.partitionEithers res_eithers of
@@ -197,16 +120,37 @@ inlinebind condition expr@(Let (Rec binds) res) = do
       ([], _) -> return expr
       (replace, others) -> do
         -- Substitute the to be replaced binders with their expression
-        let newexpr = substitute replace (Let (Rec others) res)
+        newexpr <- do_substitute replace (Let (Rec others) res)
         change newexpr
   where 
+    -- Apply the condition to a let binding and return an Either
+    -- depending on whether it needs to be inlined or not.
     docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
     docond b = do
       res <- condition b
       return $ case res of True -> Left b; False -> Right b
 
+    -- Apply the given list of substitutions to the the given expression
+    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
+    do_substitute [] expr = return expr
+    do_substitute ((bndr, val):reps) expr = do
+      -- Perform this substitution in the expression
+      expr' <- substitute_clone bndr val context expr
+      -- And in the substitution values we will be using next
+      reps' <- mapM (subs_bind bndr val) reps
+      -- And then perform the remaining substitutions
+      do_substitute reps' expr'
+   
+    -- Replace the given binder with the given expression in the
+    -- expression oft the given let binding
+    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    subs_bind bndr expr (b, v) = do
+      v' <- substitute_clone  bndr expr (LetBinding:context) v
+      return (b, v')
+
+
 -- Leave all other expressions unchanged
-inlinebind _ expr = return expr
+inlinebind _ context expr = return expr
 
 -- Sets the changed flag in the TransformMonad, to signify that some
 -- transform has changed the result
@@ -219,48 +163,69 @@ change val = do
   setChanged
   return val
 
--- Create a new Unique
-mkUnique :: TransformMonad Unique.Unique
-mkUnique = Trans.lift $ do
-    us <- getA tsUniqSupply 
-    let (us', us'') = UniqSupply.splitUniqSupply us
-    putA tsUniqSupply us'
-    return $ UniqSupply.uniqFromSupply us''
-
--- Replace each of the binders given with the coresponding expressions in the
--- given expression.
-substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> CoreExpr
-substitute [] expr = expr
--- Apply one substitution on the expression, but also on any remaining
--- substitutions. This seems to be the only way to handle substitutions like
--- [(b, c), (a, b)]. This means we reuse a substitution, which is not allowed
--- according to CoreSubst documentation (but it doesn't seem to be a problem).
--- TODO: Find out how this works, exactly.
-substitute ((b, e):subss) expr = substitute subss' expr'
-  where 
-    -- Create the Subst
-    subs = (CoreSubst.extendSubst CoreSubst.emptySubst b e)
-    -- Apply this substitution to the main expression
-    expr' = CoreSubst.substExpr subs expr
-    -- Apply this substitution on all the expressions in the remaining
-    -- substitutions
-    subss' = map (Arrow.second (CoreSubst.substExpr subs)) subss
-
--- Run a given TransformSession. Used mostly to setup the right calls and
--- an initial state.
-runTransformSession :: HscTypes.HscEnv -> UniqSupply.UniqSupply -> TransformSession a -> a
-runTransformSession env uniqSupply session = State.evalState session emptyTransformState
-  where
-    emptyTypeState = TypeState Map.empty [] Map.empty Map.empty env
-    emptyTransformState = TransformState uniqSupply Map.empty VarSet.emptyVarSet emptyTypeState
+-- Returns the given value and sets the changed flag if the bool given is
+-- True. Note that this will not unset the changed flag if the bool is False.
+changeif :: Bool -> a -> TransformMonad a
+changeif True val = change val
+changeif False val = return val
+
+-- | Creates a transformation that substitutes the given binder with the given
+-- expression (This can be a type variable, replace by a Type expression).
+-- Does not set the changed flag.
+substitute :: CoreBndr -> CoreExpr -> Transform
+-- Use CoreSubst to subst a type var in an expression
+substitute find repl context expr = do
+  let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
+  return $ CoreSubst.substExpr subst expr 
+
+-- | Creates a transformation that substitutes the given binder with the given
+-- expression. This does only work for value expressions! All binders in the
+-- expression are cloned before the replacement, to guarantee uniqueness.
+substitute_clone :: CoreBndr -> CoreExpr -> Transform
+-- If we see the var to find, replace it by a uniqued version of repl
+substitute_clone find repl context (Var var) | find == var = do
+  repl' <- Trans.lift $ CoreTools.genUniques repl
+  change repl'
+
+-- For all other expressions, just look in subexpressions
+substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
 
 -- Is the given expression representable at runtime, based on the type?
-isRepr :: CoreSyn.CoreExpr -> TransformMonad Bool
-isRepr (Type ty) = return False
-isRepr expr = Trans.lift $ MonadState.lift tsType $ VHDLTools.isReprType (CoreUtils.exprType expr)
+isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
+isRepr tything = Trans.lift (isRepr' tything)
+
+isRepr' :: (CoreTools.TypedThing t) => t -> TranslatorSession Bool
+isRepr' tything = case CoreTools.getType tything of
+  Nothing -> return False
+  Just ty -> MonadState.lift tsType $ VHDLTools.isReprType ty 
 
-is_local_var :: CoreSyn.CoreExpr -> TransformSession Bool
+is_local_var :: CoreSyn.CoreExpr -> TranslatorSession Bool
 is_local_var (CoreSyn.Var v) = do
   bndrs <- getGlobalBinders
-  return $ not $ v `elem` bndrs
+  return $ v `notElem` bndrs
 is_local_var _ = return False
+
+-- Is the given binder defined by the user?
+isUserDefined :: CoreSyn.CoreBndr -> Bool
+-- System names are certain to not be user defined
+isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
+-- Builtin functions are usually not user-defined either (and would
+-- break currently if they are...)
+isUserDefined bndr = str `notElem` builtinIds
+  where
+    str = Name.getOccString bndr
+
+-- | Is the given binder normalizable? This means that its type signature can be
+-- represented in hardware, which should (?) guarantee that it can be made
+-- into hardware. This checks whether all the arguments and (optionally)
+-- the return value are
+-- representable.
+isNormalizeable :: 
+  Bool -- ^ Allow the result to be unrepresentable?
+  -> CoreBndr  -- ^ The binder to check
+  -> TranslatorSession Bool  -- ^ Is it normalizeable?
+isNormalizeable result_nonrep bndr = do
+  let ty = Id.idType bndr
+  let (arg_tys, res_ty) = Type.splitFunTys ty
+  let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
+  andM $ mapM isRepr' check_tys