Move the application of "everywhere" to dotransforms.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize / NormalizeTools.hs
index c774a335f34f01b65fd9f6b1c51c17f29655702c..d9d4bd34e97df38f71d247d5d9ecd5707fef4665 100644 (file)
@@ -6,6 +6,7 @@ module CLasH.Normalize.NormalizeTools where
 
 -- Standard modules
 import qualified Data.Monoid as Monoid
+import qualified Data.Either as Either
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
 import qualified "transformers" Control.Monad.Trans as Trans
@@ -45,10 +46,13 @@ applyboth first (name, second) context expr = do
   if Monoid.getAny $
         -- trace ("Trying to apply transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n")
         changed 
-    then 
+    then
      -- trace ("Applying transform " ++ name ++ " to:\n" ++ showSDoc (nest 4 $ ppr expr') ++ "\nType: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr') ++ "\n"
+     --        ++ "Context: " ++ show context ++ "\n"
      --        ++ "Result of applying " ++ name ++ ":\n" ++ showSDoc (nest 4 $ ppr expr'') ++ "\n" ++ "Type: \n" ++ (showSDoc $ nest 4 $ ppr $ CoreUtils.exprType expr'') ++ "\n" ) $
-      applyboth first (name, second) context expr'' 
+      do
+        Trans.lift $ MonadState.modify tsTransformCounter (+1)
+        applyboth first (name, second) context expr'' 
     else 
       -- trace ("No changes") $
       return expr''
@@ -62,22 +66,22 @@ subeverywhere trans c (App a b) = do
   return $ App a' b'
 
 subeverywhere trans c (Let (NonRec b bexpr) expr) = do
-  bexpr' <- trans (Other:c) bexpr
-  expr' <- trans (Other:c) expr
+  bexpr' <- trans (LetBinding:c) bexpr
+  expr' <- trans (LetBody:c) expr
   return $ Let (NonRec b bexpr') expr'
 
 subeverywhere trans c (Let (Rec binds) expr) = do
-  expr' <- trans (Other:c) expr
+  expr' <- trans (LetBody:c) expr
   binds' <- mapM transbind binds
   return $ Let (Rec binds') expr'
   where
     transbind :: (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
     transbind (b, e) = do
-      e' <- trans (Other:c) e
+      e' <- trans (LetBinding:c) e
       return (b, e')
 
 subeverywhere trans c (Lam x expr) = do
-  expr' <- trans (Other:c) expr
+  expr' <- trans (LambdaBody:c) expr
   return $ Lam x expr'
 
 subeverywhere trans c (Case scrut b t alts) = do
@@ -100,41 +104,51 @@ subeverywhere trans c (Cast expr ty) = do
 
 subeverywhere trans c expr = error $ "\nNormalizeTools.subeverywhere: Unsupported expression: " ++ show expr
 
--- Apply the given transformation to all expressions, except for direct
--- arguments of an application
-notappargs :: (String, Transform) -> Transform
-notappargs trans = applyboth (subnotappargs trans) trans
-
--- Apply the given transformation to all (direct and indirect) subexpressions
--- (but not the expression itself), except for direct arguments of an
--- application
-subnotappargs :: (String, Transform) -> Transform
-subnotappargs trans c (App a b) = do
-  a' <- subnotappargs trans (Other:c) a
-  b' <- subnotappargs trans (Other:c) b
-  return $ App a' b'
-
--- Let subeverywhere handle all other expressions
-subnotappargs trans c expr = subeverywhere (notappargs trans) c expr
-
 -- Runs each of the transforms repeatedly inside the State monad.
-dotransforms :: [Transform] -> CoreExpr -> TranslatorSession CoreExpr
+dotransforms :: [(String, Transform)] -> CoreExpr -> TranslatorSession CoreExpr
 dotransforms transs expr = do
-  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> trans [] e) expr transs
+  (expr', changed) <- Writer.runWriterT $ Monad.foldM (\e trans -> everywhere trans [] e) expr transs
   if Monoid.getAny changed then dotransforms transs expr' else return expr'
 
 -- Inline all let bindings that satisfy the given condition
 inlinebind :: ((CoreBndr, CoreExpr) -> TransformMonad Bool) -> Transform
-inlinebind condition context expr@(Let (NonRec bndr expr') res) = do
-    applies <- condition (bndr, expr')
-    if applies
-      then do
-        -- Substitute the binding in res and return that
-        res' <- substitute_clone bndr expr' context res
-        change res'
-      else
-        -- Don't change this let
-        return expr
+inlinebind condition context expr@(Let (Rec binds) res) = do
+    -- Find all bindings that adhere to the condition
+    res_eithers <- mapM docond binds
+    case Either.partitionEithers res_eithers of
+      -- No replaces? No change
+      ([], _) -> return expr
+      (replace, others) -> do
+        -- Substitute the to be replaced binders with their expression
+        newexpr <- do_substitute replace (Let (Rec others) res)
+        change newexpr
+  where 
+    -- Apply the condition to a let binding and return an Either
+    -- depending on whether it needs to be inlined or not.
+    docond :: (CoreBndr, CoreExpr) -> TransformMonad (Either (CoreBndr, CoreExpr) (CoreBndr, CoreExpr))
+    docond b = do
+      res <- condition b
+      return $ case res of True -> Left b; False -> Right b
+
+    -- Apply the given list of substitutions to the the given expression
+    do_substitute :: [(CoreBndr, CoreExpr)] -> CoreExpr -> TransformMonad CoreExpr
+    do_substitute [] expr = return expr
+    do_substitute ((bndr, val):reps) expr = do
+      -- Perform this substitution in the expression
+      expr' <- substitute_clone bndr val context expr
+      -- And in the substitution values we will be using next
+      reps' <- mapM (subs_bind bndr val) reps
+      -- And then perform the remaining substitutions
+      do_substitute reps' expr'
+   
+    -- Replace the given binder with the given expression in the
+    -- expression oft the given let binding
+    subs_bind :: CoreBndr -> CoreExpr -> (CoreBndr, CoreExpr) -> TransformMonad (CoreBndr, CoreExpr)
+    subs_bind bndr expr (b, v) = do
+      v' <- substitute_clone  bndr expr (LetBinding:context) v
+      return (b, v')
+
+
 -- Leave all other expressions unchanged
 inlinebind _ context expr = return expr
 
@@ -201,17 +215,17 @@ isUserDefined bndr = str `notElem` builtinIds
   where
     str = Name.getOccString bndr
 
--- Is the given binder normalizable? This means that its type signature can be
+-- Is the given binder normalizable? This means that its type signature can be
 -- represented in hardware, which should (?) guarantee that it can be made
--- into hardware. Note that if a binder is not normalizable, it might become
--- so using argument propagation.
-isNormalizeable :: CoreBndr -> TransformMonad Bool 
-isNormalizeable bndr = Trans.lift (isNormalizeable' bndr)
-
-isNormalizeable' :: CoreBndr -> TranslatorSession Bool 
-isNormalizeable' bndr = do
+-- into hardware. This checks whether all the arguments and (optionally)
+-- the return value are
+-- representable.
+isNormalizeable :: 
+  Bool -- ^ Allow the result to be unrepresentable?
+  -> CoreBndr  -- ^ The binder to check
+  -> TranslatorSession Bool  -- ^ Is it normalizeable?
+isNormalizeable result_nonrep bndr = do
   let ty = Id.idType bndr
   let (arg_tys, res_ty) = Type.splitFunTys ty
-  -- This function is normalizable if all its arguments and return value are
-  -- representable.
-  andM $ mapM isRepr' (res_ty:arg_tys)
+  let check_tys = if result_nonrep then arg_tys else (res_ty:arg_tys) 
+  andM $ mapM isRepr' check_tys