Make debug output controllable with a top-level "constant".
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
index c774a335f34f01b65fd9f6b1c51c17f29655702c..f6c254e431381376700ed528a88cdcd9de9c118a 100644 (file)
@@ -6,11 +6,11 @@ module CLasH.Normalize.NormalizeTools where
 
 -- Standard modules
 import qualified Data.Monoid as Monoid
 
 -- Standard modules
 import qualified Data.Monoid as Monoid
+import qualified Data.Either as Either
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Data.Accessor.Monad.Trans.State as MonadState
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Data.Accessor.Monad.Trans.State as MonadState
--- import Debug.Trace
 
 -- GHC API
 import CoreSyn
 
 -- GHC API
 import CoreSyn
@@ -18,8 +18,8 @@ import qualified Name
 import qualified Id
 import qualified CoreSubst
 import qualified Type
 import qualified Id
 import qualified CoreSubst
 import qualified Type
--- import qualified CoreUtils
--- import Outputable ( showSDoc, ppr, nest )
+import qualified CoreUtils
+import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
@@ -31,26 +31,45 @@ import qualified CLasH.VHDL.VHDLTools as VHDLTools
 
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
 
 -- Apply the given transformation to all expressions in the given expression,
 -- including the expression itself.
-everywhere :: (String, Transform) -> Transform
+everywhere :: Transform -> Transform
 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
 
 everywhere trans = applyboth (subeverywhere (everywhere trans)) trans
 
+data NormDbgLevel = 
+    NormDbgNone         -- ^ No debugging
+  | NormDbgFinal        -- ^ Print functions before / after normalization
+  | NormDbgApplied      -- ^ Print expressions before / after applying transformations
+  | NormDbgAll          -- ^ Print expressions when a transformation does not apply
+  deriving (Eq, Ord)
+normalize_debug = NormDbgFinal
+
+-- Applies a transform, optionally showing some debug output.
+apply :: (String, Transform) -> Transform
+apply (name, trans) ctx expr =  do
+    -- Apply the transformation and find out if it changed anything
+    (expr', any_changed) <- Writer.listen $ trans ctx expr
+    let changed = Monoid.getAny any_changed
+    -- If it changed, increase the transformation counter 
+    Monad.when changed $ Trans.lift (MonadState.modify tsTransformCounter (+1))
+    -- Prepare some debug strings
+    let before = showSDoc (nest 4 $ ppr expr) ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr) ++ "\n"
+    let context = "Context: " ++ show ctx ++ "\n"
+    let after  = showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
+    traceIf (normalize_debug >= NormDbgApplied && changed) ("Changes when applying transform " ++ name ++ " to:\n" ++ before ++ context ++ "Result:\n" ++ after) $ 
+     traceIf (normalize_debug >= NormDbgAll && not changed) ("No changes when applying transform " ++ name ++ " to:\n" ++ before  ++ context) $
+     return expr'
+
 -- Apply the first transformation, followed by the second transformation, and
 -- keep applying both for as long as expression still changes.
 -- Apply the first transformation, followed by the second transformation, and
 -- keep applying both for as long as expression still changes.
-applyboth :: Transform -> (String, Transform) -> Transform
-applyboth first (name, second) context expr = do
+applyboth :: Transform -> Transform -> Transform
+applyboth first second context expr = do
   -- Apply the first
   expr' <- first context expr
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second context expr'
   -- Apply the first
   expr' <- first context expr
   -- Apply the second
   (expr'', changed) <- Writer.listen $ second context expr'
-  if Monoid.getAny $
-        -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
-        changed 
-    then 
-     -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
-     --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
-      applyboth first (name, second) context expr'' 
+  if Monoid.getAny $ changed
+    then
+      applyboth first second context expr'' 
     else 
     else 
-      -- trace ("No changes") $
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
       return expr''
 
 -- Apply the given transformation to all direct subexpressions (only), not the
@@ -62,22 +81,22 @@ subeverywhere trans c (App a b) = do
   return $ App a' b'
 
 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
   return $ App a' b'
 
 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
-  bexpr' <- trans (Other:c) bexpr
-  expr' <- trans (Other:c) expr
+  bexpr' <- trans (LetBinding:c) bexpr
+  expr' <- trans (LetBody:c) expr
   return $ Let (NonRec b bexpr') expr'
 
 subeverywhere trans c (Let (Rec binds) expr) = do
   return $ Let (NonRec b bexpr') expr'
 
 subeverywhere trans c (Let (Rec binds) expr) = do
-  expr' <- trans (Other:c) expr
+  expr' <- trans (LetBody:c) expr
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
-      e' <- trans (Other:c) e
+      e' <- trans (LetBinding:c) e
       return (b, e')
 
 subeverywhere trans c (Lam x expr) = do
       return (b, e')
 
 subeverywhere trans c (Lam x expr) = do
-  expr' <- trans (Other:c) expr
+  expr' <- trans (LambdaBody:c) expr
   return $ Lam x expr'
 
 subeverywhere trans c (Case scrut b t alts) = do
   return $ Lam x expr'
 
 subeverywhere trans c (Case scrut b t alts) = do
@@ -100,41 +119,51 @@ subeverywhere trans c (Cast expr ty) = do
 
 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
 
 
 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
 
--- Apply the given transformation to all expressions, except for direct
--- arguments of an application
-notappargs :: (String, Transform) -> Transform
-notappargs trans = applyboth (subnotappargs trans) trans
-
--- Apply the given transformation to all (direct and indirect) subexpressions
--- (but not the expression itself), except for direct arguments of an
--- application
-subnotappargs :: (String, Transform) -> Transform
-subnotappargs trans c (App a b) = do
-  a' <- subnotappargs trans (Other:c) a
-  b' <- subnotappargs trans (Other:c) b
-  return $ App a' b'
-
--- Let subeverywhere handle all other expressions
-subnotappargs trans c expr = subeverywhere (notappargs trans) c expr
-
 -- Runs each of the transforms repeatedly inside the State monad.
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
+dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
 dotransforms transs expr = do
 dotransforms transs expr = do
-  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> trans [] e) expr transs
+  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere (apply trans) [] e) expr transs
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition context expr@(Let (NonRec bndr expr') res) = do
-    applies <- condition (bndr, expr')
-    if applies
-      then do
-        -- Substitute the binding in res and return that
-        res' <- substitute_clone bndr expr' context res
-        change res'
-      else
-        -- Don't change this let
-        return expr
+inlinebind condition context expr@(Let (Rec binds) res) = do
+    -- Find all bindings that adhere to the condition
+    res_eithers <- mapM docond binds
+    case Either.partitionEithers res_eithers of
+      -- No replaces? No change
+      ([], _) -> return expr
+      (replace, others) -> do
+        -- Substitute the to be replaced binders with their expression
+        newexpr <- do_substitute replace (Let (Rec others) res)
+        change newexpr
+  where 
+    -- Apply the condition to a let binding and return an Either
+    -- depending on whether it needs to be inlined or not.
+    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
+    docond b = do
+      res <- condition b
+      return $ case res of True -> Left b; False -> Right b
+
+    -- Apply the given list of substitutions to the the given expression
+    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
+    do_substitute [] expr = return expr
+    do_substitute ((bndr, val):reps) expr = do
+      -- Perform this substitution in the expression
+      expr' <- substitute_clone bndr val context expr
+      -- And in the substitution values we will be using next
+      reps' <- mapM (subs_bind bndr val) reps
+      -- And then perform the remaining substitutions
+      do_substitute reps' expr'
+   
+    -- Replace the given binder with the given expression in the
+    -- expression oft the given let binding
+    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    subs_bind bndr expr (b, v) = do
+      v' <- substitute_clone  bndr expr (LetBinding:context) v
+      return (b, v')
+
+
 -- Leave all other expressions unchanged
 inlinebind _ context expr = return expr
 
 -- Leave all other expressions unchanged
 inlinebind _ context expr = return expr
 
@@ -201,17 +230,17 @@ isUserDefined bndr = str `notElem` builtinIds
   where
     str = Name.getOccString bndr
 
   where
     str = Name.getOccString bndr
 
--- Is the given binder normalizable? This means that its type signature can be
+-- Is the given binder normalizable? This means that its type signature can be
 -- represented in hardware, which should (?) guarantee that it can be made
 -- represented in hardware, which should (?) guarantee that it can be made
--- into hardware. Note that if a binder is not normalizable, it might become
--- so using argument propagation.
-isNormalizeable :: CoreBndr -> TransformMonad Bool 
-isNormalizeable bndr = Trans.lift (isNormalizeable' bndr)
-
-isNormalizeable' :: CoreBndr -> TranslatorSession Bool 
-isNormalizeable' bndr = do
+-- into hardware. This checks whether all the arguments and (optionally)
+-- the return value are
+-- representable.
+isNormalizeable :: 
+  Bool -- ^ Allow the result to be unrepresentable?
+  -> CoreBndr  -- ^ The binder to check
+  -> TranslatorSession Bool  -- ^ Is it normalizeable?
+isNormalizeable result_nonrep bndr = do
   let ty = Id.idType bndr
   let (arg_tys, res_ty) = Type.splitFunTys ty
   let ty = Id.idType bndr
   let (arg_tys, res_ty) = Type.splitFunTys ty
-  -- This function is normalizable if all its arguments and return value are
-  -- representable.
-  andM $ mapM isRepr' (res_ty:arg_tys)
+  let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
+  andM $ mapM isRepr' check_tys