Make top level inlining handle non-representable results gracefully.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index a5b2a9474d290dbb7fb6d0804c7f3c6622fb41b7..620482f703547f65f3dff7eb888742f70e67421d 100644 (file)
@@ -45,17 +45,16 @@ import CLasH.Utils.Pretty
 --------------------------------
 
 --------------------------------
--- η abstraction
---------------------------------
+-- η expansion
+--------------------------------
+-- Make sure all parameters to the normalized functions are named by top
+-- level lambda expressions. For this we apply η expansion to the
+-- function body (possibly enclosed in some lambda abstractions) while
+-- it has a function type. Eventually this will result in a function
+-- body consisting of a bunch of nested lambdas containing a
+-- non-function value (e.g., a complete application).
 eta, etatop :: Transform
--- Don't apply to expressions that are applied, since that would cause
--- us to apply to our own result indefinitely.
-eta (AppFirst:_) expr = return expr
--- Also don't apply to arguments, since this can cause loops with
--- funextract. This isn't the proper solution, but due to an
--- implementation bug in notappargs, this is how it used to work so far.
-eta (AppSecond:_) expr = return expr
-eta c expr | is_fun expr && not (is_lam expr) = do
+eta c expr | is_fun expr && not (is_lam expr) && all (== LambdaBody) c = do
   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
   id <- Trans.lift $ mkInternalVar "param" arg_ty
   change (Lam id (App expr (Var id)))
@@ -121,86 +120,55 @@ castsimpl c expr = return expr
 -- Perform this transform everywhere
 castsimpltop = everywhere ("castsimpl", castsimpl)
 
-
 --------------------------------
--- Lambda simplication
+-- Return value simplification
 --------------------------------
--- Ensure that a lambda always evaluates to a let expressions or a simple
--- variable reference.
-lambdasimpl, lambdasimpltop :: Transform
--- Don't simplify a lambda that evaluates to let, since this is already
--- normal form (and would cause infinite loops).
-lambdasimpl c expr@(Lam _ (Let _ _)) = return expr
--- Put the of a lambda in its own binding, but not when the expression is
--- already a local variable, or not representable (to prevent loops with
--- inlinenonrep).
-lambdasimpl c expr@(Lam bndr res) = do
-  repr <- isRepr res
-  local_var <- Trans.lift $ is_local_var res
+-- Ensure the return value of a function follows proper normal form. eta
+-- expansion ensures the body starts with lambda abstractions, this
+-- transformation ensures that the lambda abstractions always contain a
+-- recursive let and that, when the return value is representable, the
+-- let contains a local variable reference in its body.
+retvalsimpl c expr | all (== LambdaBody) c && not (is_lam expr) && not (is_let expr) = do
+  local_var <- Trans.lift $ is_local_var expr
+  repr <- isRepr expr
+  if not local_var && repr
+    then do
+      id <- Trans.lift $ mkBinderFor expr "res" 
+      change $ Let (Rec [(id, expr)]) (Var id)
+    else
+      return expr
+
+retvalsimpl c expr@(Let (Rec binds) body) | all (== LambdaBody) c = do
+  -- Don't extract values that are already a local variable, to prevent
+  -- loops with ourselves.
+  local_var <- Trans.lift $ is_local_var body
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr body
   if not local_var && repr
     then do
-      id <- Trans.lift $ mkBinderFor res "res"
-      change $ Lam bndr (Let (NonRec id res) (Var id))
+      id <- Trans.lift $ mkBinderFor body "res" 
+      change $ Let (Rec ((id, body):binds)) (Var id)
     else
-      -- If the result is already a local var or not representable, don't
-      -- extract it.
       return expr
 
+
 -- Leave all other expressions unchanged
-lambdasimpl c expr = return expr
+retvalsimpl c expr = return expr
 -- Perform this transform everywhere
-lambdasimpltop = everywhere ("lambdasimpl", lambdasimpl)
+retvalsimpltop = everywhere ("retvalsimpl", retvalsimpl)
 
 --------------------------------
 -- let derecursification
 --------------------------------
-letderec, letderectop :: Transform
-letderec c expr@(Let (Rec binds) res) = case liftable of
-  -- Nothing is liftable, just return
-  [] -> return expr
-  -- Something can be lifted, generate a new let expression
-  _ -> change $ mkNonRecLets liftable (Let (Rec nonliftable) res)
-  where
-    -- Make a list of all the binders bound in this recursive let
-    bndrs = map fst binds
-    -- See which bindings are liftable
-    (liftable, nonliftable) = List.partition canlift binds
-    -- Any expression that does not use any of the binders in this recursive let
-    -- can be lifted into a nonrec let. It can't use its own binder either,
-    -- since that would mean the binding is self-recursive and should be in a
-    -- single bind recursive let.
-    canlift (bndr, e) = not $ expr_uses_binders bndrs e
--- Leave all other expressions unchanged
-letderec c expr = return expr
--- Perform this transform everywhere
-letderectop = everywhere ("letderec", letderec)
-
---------------------------------
--- let simplification
---------------------------------
-letsimpl, letsimpltop :: Transform
--- Don't simplify a let that evaluates to another let, since this is already
--- normal form (and would cause infinite loops with letflat below).
-letsimpl c expr@(Let _ (Let _ _)) = return expr
--- Put the "in ..." value of a let in its own binding, but not when the
--- expression is already a local variable, or not representable (to prevent loops with inlinenonrep).
-letsimpl c expr@(Let binds res) = do
-  repr <- isRepr res
-  local_var <- Trans.lift $ is_local_var res
-  if not local_var && repr
-    then do
-      -- If the result is not a local var already (to prevent loops with
-      -- ourselves), extract it.
-      id <- Trans.lift $ mkBinderFor res "foo"
-      change $ Let binds (Let (NonRec id  res) (Var id))
-    else
-      -- If the result is already a local var, don't extract it.
-      return expr
+letrec, letrectop :: Transform
+letrec c expr@(Let (NonRec bndr val) res) = 
+  change $ Let (Rec [(bndr, val)]) res
 
 -- Leave all other expressions unchanged
-letsimpl c expr = return expr
+letrec c expr = return expr
 -- Perform this transform everywhere
-letsimpltop = everywhere ("letsimpl", letsimpl)
+letrectop = everywhere ("letrec", letrec)
 
 --------------------------------
 -- let flattening
@@ -327,44 +295,46 @@ inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . is
 --------------------------------
 -- Top level function inlining
 --------------------------------
--- This transformation inlines top level bindings that have been generated by
--- the compiler and are really simple. Really simple currently means that the
--- normalized form only contains a single binding, which catches most of the
+-- This transformation inlines simple top level bindings. Simple
+-- currently means that the body is only a single application (though
+-- the complexity of the arguments is not currently checked) or that the
+-- normalized form only contains a single binding. This should catch most of the
 -- cases where a top level function is created that simply calls a type class
 -- method with a type and dictionary argument, e.g.
 --   fromInteger = GHC.Num.fromInteger (SizedWord D8) $dNum
 -- which is later called using simply
 --   fromInteger (smallInteger 10)
--- By inlining such calls to simple, compiler generated functions, we prevent
--- huge amounts of trivial components in the VHDL output, which the user never
--- wanted. We never inline user-defined functions, since we want to preserve
--- all structure defined by the user. Currently this includes all functions
--- that were created by funextract, since we would get loops otherwise.
 --
--- Note that "defined by the compiler" isn't completely watertight, since GHC
--- doesn't seem to set all those names as "system names", we apply some
--- guessing here.
+-- These useless wrappers are created by GHC automatically. If we don't
+-- inline them, we get loads of useless components cluttering the
+-- generated VHDL.
+--
+-- Note that the inlining could also inline simple functions defined by
+-- the user, not just GHC generated functions. It turns out to be near
+-- impossible to reliably determine what functions are generated and
+-- what functions are user-defined. Instead of guessing (which will
+-- inline less than we want) we will just inline all simple functions.
+--
+-- Only functions that are actually completely applied and bound by a
+-- variable in a let expression are inlined. These are the expressions
+-- that will eventually generate instantiations of trivial components.
+-- By not inlining any other reference, we also prevent looping problems
+-- with funextract and inlinedict.
 inlinetoplevel, inlinetopleveltop :: Transform
--- HACK: Don't inline == and /=. The default (derived) implementation
--- for /= uses the polymorphic version of ==, which gets a dictionary
--- for Eq passed in, which contains a reference to itself, resulting in
--- an infinite loop in transformation. Not inlining == is really a hack,
--- but for now it keeps things working with the most common symptom of
--- this problem.
-inlinetoplevel c expr@(Var f) | Name.getOccString f `elem` ["==", "/="] = return expr
--- Any system name is candidate for inlining. Never inline user-defined
--- functions, to preserve structure.
-inlinetoplevel c expr@(Var f) | not $ isUserDefined f = do
-  body_maybe <- needsInline f
-  case body_maybe of
-    Just body -> do
-        -- Regenerate all uniques in the to-be-inlined expression
-        body_uniqued <- Trans.lift $ genUniques body
-        -- And replace the variable reference with the unique'd body.
-        change body_uniqued
-        -- No need to inline
-    Nothing -> return expr
-
+inlinetoplevel (LetBinding:_) expr | not (is_fun expr) =
+  case collectArgs expr of
+       (Var f, args) -> do
+         body_maybe <- needsInline f
+         case body_maybe of
+               Just body -> do
+                       -- Regenerate all uniques in the to-be-inlined expression
+                       body_uniqued <- Trans.lift $ genUniques body
+                       -- And replace the variable reference with the unique'd body.
+                       change (mkApps body_uniqued args)
+                       -- No need to inline
+               Nothing -> return expr
+       -- This is not an application of a binder, leave it unchanged.
+       _ -> return expr
 
 -- Leave all other expressions unchanged
 inlinetoplevel c expr = return expr
@@ -388,23 +358,35 @@ needsInline f = do
         case norm_maybe of
           -- Noth normalizeable
           Nothing -> return Nothing 
-          Just norm -> case splitNormalized norm of
+          Just norm -> case splitNormalizedNonRep norm of
             -- The function has just a single binding, so that's simple
             -- enough to inline.
-            (args, [bind], res) -> return $ Just norm
+            (args, [bind], Var res) -> return $ Just norm
             -- More complicated function, don't inline
             _ -> return Nothing
             
 --------------------------------
 -- Dictionary inlining
 --------------------------------
--- Inline all top level dictionaries, so we can use them to resolve
--- class methods based on the dictionary passed. 
-inlinedict c expr@(Var f) | Id.isDictId f = do
-  body_maybe <- Trans.lift $ getGlobalBind f
+-- Inline all top level dictionaries, that are in a position where
+-- classopresolution can actually resolve them. This makes this
+-- transformation look similar to classoperesolution below, but we'll
+-- keep them separated for clarity. By not inlining other dictionaries,
+-- we prevent expression sizes exploding when huge type level integer
+-- dictionaries are inlined which can never be expanded (in casts, for
+-- example).
+inlinedict c expr@(App (App (Var sel) ty) (Var dict)) | not is_builtin && is_classop = do
+  body_maybe <- Trans.lift $ getGlobalBind dict
   case body_maybe of
+    -- No body available (no source available, or a local variable /
+    -- argument)
     Nothing -> return expr
-    Just body -> change body
+    Just body -> change (App (App (Var sel) ty) body)
+  where
+    -- Is this a builtin function / method?
+    is_builtin = elem (Name.getOccString sel) builtinIds
+    -- Are we dealing with a class operation selector?
+    is_classop = Maybe.isJust (Id.isClassOpId_maybe sel)
 
 -- Leave all other expressions unchanged
 inlinedict c expr = return expr
@@ -806,25 +788,6 @@ funextract c expr = return expr
 -- Perform this transform everywhere
 funextracttop = everywhere ("funextract", funextract)
 
---------------------------------
--- Ensure that a function that just returns another function (or rather,
--- another top-level binder) is still properly normalized. This is a temporary
--- solution, we should probably integrate this pass with lambdasimpl and
--- letsimpl instead.
---------------------------------
-simplrestop c expr@(Lam _ _) = return expr
-simplrestop c expr@(Let _ _) = return expr
-simplrestop c expr = do
-  local_var <- Trans.lift $ is_local_var expr
-  -- Don't extract values that are not representable, to prevent loops with
-  -- inlinenonrep
-  repr <- isRepr expr
-  if local_var || not repr
-    then
-      return expr
-    else do
-      id <- Trans.lift $ mkBinderFor expr "res" 
-      change $ Let (NonRec id expr) (Var id)
 --------------------------------
 -- End of transformations
 --------------------------------
@@ -833,7 +796,7 @@ simplrestop c expr = do
 
 
 -- What transforms to run?
-transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letrectop, letremovetop, retvalsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function, or an error
 -- if it is not a known global binder.
@@ -878,22 +841,32 @@ normalizeExpr ::
   -> TranslatorSession CoreSyn.CoreExpr -- ^ The normalized expression
 
 normalizeExpr what expr = do
+      startcount <- MonadState.get tsTransformCounter 
       expr_uniqued <- genUniques expr
       -- Normalize this expression
       trace (what ++ " before normalization:\n\n" ++ showSDoc ( ppr expr_uniqued ) ++ "\n") $ return ()
       expr' <- dotransforms transforms expr_uniqued
-      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
-      return expr'
+      endcount <- MonadState.get tsTransformCounter 
+      trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')
+             ++ "\nNeeded " ++ show (endcount - startcount) ++ " transformations to normalize " ++ what) $
+       return expr'
 
 -- | Split a normalized expression into the argument binders, top level
---   bindings and the result binder.
+--   bindings and the result binder. This function returns an error if
+--   the type of the expression is not representable.
 splitNormalized ::
   CoreExpr -- ^ The normalized expression
   -> ([CoreBndr], [Binding], CoreBndr)
-splitNormalized expr = (args, binds, res)
+splitNormalized expr = 
+  case splitNormalizedNonRep expr of
+    (args, binds, Var res) -> (args, binds, res)
+    _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"
+
+-- Split a normalized expression, whose type can be unrepresentable.
+splitNormalizedNonRep::
+  CoreExpr -- ^ The normalized expression
+  -> ([CoreBndr], [Binding], CoreExpr)
+splitNormalizedNonRep expr = (args, binds, resexpr)
   where
     (args, letexpr) = CoreSyn.collectBinders expr
     (binds, resexpr) = flattenLets letexpr
-    res = case resexpr of 
-      (Var x) -> x
-      _ -> error $ "Normalize.splitNormalized: Not in normal form: " ++ pprString expr ++ "\n"