Move the application of "everywhere" to dotransforms.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
index a291ccc7b7ba9f25fcc971d728118e67626899e9..d9d4bd34e97df38f71d247d5d9ecd5707fef4665 100644 (file)
@@ -6,6 +6,7 @@ module CLasH.Normalize.NormalizeTools where
 
 -- Standard modules
 import qualified Data.Monoid as Monoid
 
 -- Standard modules
 import qualified Data.Monoid as Monoid
+import qualified Data.Either as Either
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
@@ -24,6 +25,7 @@ import qualified Type
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
+import CLasH.VHDL.Constants (builtinIds)
 import CLasH.Utils
 import qualified CLasH.Utils.Core.CoreTools as CoreTools
 import qualified CLasH.VHDL.VHDLTools as VHDLTools
 import CLasH.Utils
 import qualified CLasH.Utils.Core.CoreTools as CoreTools
 import qualified CLasH.VHDL.VHDLTools as VHDLTools
@@ -36,19 +38,21 @@ everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
 -- Apply the first transformation, followed by the second transformation, and
 -- keep applying both for as long as expression still changes.
 applyboth :: Transform -> (String, Transform) -> Transform
 -- Apply the first transformation, followed by the second transformation, and
 -- keep applying both for as long as expression still changes.
 applyboth :: Transform -> (String, Transform) -> Transform
-applyboth first (name, second) expr = do
+applyboth first (name, second) context expr = do
   -- Apply the first
   -- Apply the first
-  expr' <- first expr
+  expr' <- first context expr
   -- Apply the second
   -- Apply the second
-  (expr'', changed) <- Writer.listen $ second expr'
+  (expr'', changed) <- Writer.listen $ second context expr'
   if Monoid.getAny $
         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
         changed 
   if Monoid.getAny $
         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
         changed 
-    then 
+    then
      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
+     --        ++ "Context: " ++ show context ++ "\n"
      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
-      applyboth first (name, second)
-        expr'' 
+      do
+        Trans.lift $ MonadState.modify tsTransformCounter (+1)
+        applyboth first (name, second) context expr'' 
     else 
       -- trace ("No changes") $
       return expr''
     else 
       -- trace ("No changes") $
       return expr''
@@ -56,87 +60,97 @@ applyboth first (name, second) expr = do
 -- Apply the given transformation to all direct subexpressions (only), not the
 -- expression itself.
 subeverywhere :: Transform -> Transform
 -- Apply the given transformation to all direct subexpressions (only), not the
 -- expression itself.
 subeverywhere :: Transform -> Transform
-subeverywhere trans (App a b) = do
-  a' <- trans a
-  b' <- trans b
+subeverywhere trans (App a b) = do
+  a' <- trans (AppFirst:c) a
+  b' <- trans (AppSecond:c) b
   return $ App a' b'
 
   return $ App a' b'
 
-subeverywhere trans (Let (NonRec b bexpr) expr) = do
-  bexpr' <- trans bexpr
-  expr' <- trans expr
+subeverywhere trans (Let (NonRec b bexpr) expr) = do
+  bexpr' <- trans (LetBinding:c) bexpr
+  expr' <- trans (LetBody:c) expr
   return $ Let (NonRec b bexpr') expr'
 
   return $ Let (NonRec b bexpr') expr'
 
-subeverywhere trans (Let (Rec binds) expr) = do
-  expr' <- trans expr
+subeverywhere trans (Let (Rec binds) expr) = do
+  expr' <- trans (LetBody:c) expr
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
-      e' <- trans e
+      e' <- trans (LetBinding:c) e
       return (b, e')
 
       return (b, e')
 
-subeverywhere trans (Lam x expr) = do
-  expr' <- trans expr
+subeverywhere trans (Lam x expr) = do
+  expr' <- trans (LambdaBody:c) expr
   return $ Lam x expr'
 
   return $ Lam x expr'
 
-subeverywhere trans (Case scrut b t alts) = do
-  scrut' <- trans scrut
+subeverywhere trans (Case scrut b t alts) = do
+  scrut' <- trans (Other:c) scrut
   alts' <- mapM transalt alts
   return $ Case scrut' b t alts'
   where
     transalt :: CoreAlt -> TransformMonad CoreAlt
     transalt (con, binders, expr) = do
   alts' <- mapM transalt alts
   return $ Case scrut' b t alts'
   where
     transalt :: CoreAlt -> TransformMonad CoreAlt
     transalt (con, binders, expr) = do
-      expr' <- trans expr
+      expr' <- trans (Other:c) expr
       return (con, binders, expr')
 
       return (con, binders, expr')
 
-subeverywhere trans (Var x) = return $ Var x
-subeverywhere trans (Lit x) = return $ Lit x
-subeverywhere trans (Type x) = return $ Type x
+subeverywhere trans (Var x) = return $ Var x
+subeverywhere trans (Lit x) = return $ Lit x
+subeverywhere trans (Type x) = return $ Type x
 
 
-subeverywhere trans (Cast expr ty) = do
-  expr' <- trans expr
+subeverywhere trans (Cast expr ty) = do
+  expr' <- trans (Other:c) expr
   return $ Cast expr' ty
 
   return $ Cast expr' ty
 
-subeverywhere trans expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
-
--- Apply the given transformation to all expressions, except for direct
--- arguments of an application
-notappargs :: (String, Transform) -> Transform
-notappargs trans = applyboth (subnotappargs trans) trans
-
--- Apply the given transformation to all (direct and indirect) subexpressions
--- (but not the expression itself), except for direct arguments of an
--- application
-subnotappargs :: (String, Transform) -> Transform
-subnotappargs trans (App a b) = do
-  a' <- subnotappargs trans a
-  b' <- subnotappargs trans b
-  return $ App a' b'
-
--- Let subeverywhere handle all other expressions
-subnotappargs trans expr = subeverywhere (notappargs trans) expr
+subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
 
 -- Runs each of the transforms repeatedly inside the State monad.
 
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
+dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
 dotransforms transs expr = do
 dotransforms transs expr = do
-  (expr', changed) <- Writer.runWriterT $ Monad.foldM (flip ($)) expr transs
+  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere trans [] e) expr transs
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition expr@(Let (NonRec bndr expr') res) = do
-    applies <- condition (bndr, expr')
-    if applies
-      then do
-        -- Substitute the binding in res and return that
-        res' <- substitute_clone bndr expr' res
-        change res'
-      else
-        -- Don't change this let
-        return expr
+inlinebind condition context expr@(Let (Rec binds) res) = do
+    -- Find all bindings that adhere to the condition
+    res_eithers <- mapM docond binds
+    case Either.partitionEithers res_eithers of
+      -- No replaces? No change
+      ([], _) -> return expr
+      (replace, others) -> do
+        -- Substitute the to be replaced binders with their expression
+        newexpr <- do_substitute replace (Let (Rec others) res)
+        change newexpr
+  where 
+    -- Apply the condition to a let binding and return an Either
+    -- depending on whether it needs to be inlined or not.
+    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
+    docond b = do
+      res <- condition b
+      return $ case res of True -> Left b; False -> Right b
+
+    -- Apply the given list of substitutions to the the given expression
+    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
+    do_substitute [] expr = return expr
+    do_substitute ((bndr, val):reps) expr = do
+      -- Perform this substitution in the expression
+      expr' <- substitute_clone bndr val context expr
+      -- And in the substitution values we will be using next
+      reps' <- mapM (subs_bind bndr val) reps
+      -- And then perform the remaining substitutions
+      do_substitute reps' expr'
+   
+    -- Replace the given binder with the given expression in the
+    -- expression oft the given let binding
+    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    subs_bind bndr expr (b, v) = do
+      v' <- substitute_clone  bndr expr (LetBinding:context) v
+      return (b, v')
+
+
 -- Leave all other expressions unchanged
 -- Leave all other expressions unchanged
-inlinebind _ expr = return expr
+inlinebind _ context expr = return expr
 
 -- Sets the changed flag in the TransformMonad, to signify that some
 -- transform has changed the result
 
 -- Sets the changed flag in the TransformMonad, to signify that some
 -- transform has changed the result
@@ -160,7 +174,7 @@ changeif False val = return val
 -- Does not set the changed flag.
 substitute :: CoreBndr -> CoreExpr -> Transform
 -- Use CoreSubst to subst a type var in an expression
 -- Does not set the changed flag.
 substitute :: CoreBndr -> CoreExpr -> Transform
 -- Use CoreSubst to subst a type var in an expression
-substitute find repl expr = do
+substitute find repl context expr = do
   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
   return $ CoreSubst.substExpr subst expr 
 
   let subst = CoreSubst.extendSubst CoreSubst.emptySubst find repl
   return $ CoreSubst.substExpr subst expr 
 
@@ -169,12 +183,12 @@ substitute find repl expr = do
 -- expression are cloned before the replacement, to guarantee uniqueness.
 substitute_clone :: CoreBndr -> CoreExpr -> Transform
 -- If we see the var to find, replace it by a uniqued version of repl
 -- expression are cloned before the replacement, to guarantee uniqueness.
 substitute_clone :: CoreBndr -> CoreExpr -> Transform
 -- If we see the var to find, replace it by a uniqued version of repl
-substitute_clone find repl (Var var) | find == var = do
+substitute_clone find repl context (Var var) | find == var = do
   repl' <- Trans.lift $ CoreTools.genUniques repl
   change repl'
 
 -- For all other expressions, just look in subexpressions
   repl' <- Trans.lift $ CoreTools.genUniques repl
   change repl'
 
 -- For all other expressions, just look in subexpressions
-substitute_clone find repl expr = subeverywhere (substitute_clone find repl) expr
+substitute_clone find repl context expr = subeverywhere (substitute_clone find repl) context expr
 
 -- Is the given expression representable at runtime, based on the type?
 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
 
 -- Is the given expression representable at runtime, based on the type?
 isRepr :: (CoreTools.TypedThing t) => t -> TransformMonad Bool
@@ -195,26 +209,23 @@ is_local_var _ = return False
 isUserDefined :: CoreSyn.CoreBndr -> Bool
 -- System names are certain to not be user defined
 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
 isUserDefined :: CoreSyn.CoreBndr -> Bool
 -- System names are certain to not be user defined
 isUserDefined bndr | Name.isSystemName (Id.idName bndr) = False
--- Check a list of typical compiler-defined names
-isUserDefined bndr = str `notElem` compiler_names
+-- Builtin functions are usually not user-defined either (and would
+-- break currently if they are...)
+isUserDefined bndr = str `notElem` builtinIds
   where
     str = Name.getOccString bndr
   where
     str = Name.getOccString bndr
-    -- These are names of bindings usually generated by the compiler. For some
-    -- reason these are not marked as system, probably because the name itself
-    -- is not made up by the compiler, just this particular binding is.
-    compiler_names = ["fromInteger", "head", "tail", "init", "last", "+", "*", "-", "!"]
 
 
--- Is the given binder normalizable? This means that its type signature can be
+-- Is the given binder normalizable? This means that its type signature can be
 -- represented in hardware, which should (?) guarantee that it can be made
 -- represented in hardware, which should (?) guarantee that it can be made
--- into hardware. Note that if a binder is not normalizable, it might become
--- so using argument propagation.
-isNormalizeable :: CoreBndr -> TransformMonad Bool 
-isNormalizeable bndr = Trans.lift (isNormalizeable' bndr)
-
-isNormalizeable' :: CoreBndr -> TranslatorSession Bool 
-isNormalizeable' bndr = do
+-- into hardware. This checks whether all the arguments and (optionally)
+-- the return value are
+-- representable.
+isNormalizeable :: 
+  Bool -- ^ Allow the result to be unrepresentable?
+  -> CoreBndr  -- ^ The binder to check
+  -> TranslatorSession Bool  -- ^ Is it normalizeable?
+isNormalizeable result_nonrep bndr = do
   let ty = Id.idType bndr
   let (arg_tys, res_ty) = Type.splitFunTys ty
   let ty = Id.idType bndr
   let (arg_tys, res_ty) = Type.splitFunTys ty
-  -- This function is normalizable if all its arguments and return value are
-  -- representable.
-  andM $ mapM isRepr' (res_ty:arg_tys)
+  let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
+  andM $ mapM isRepr' check_tys