Add cast simplification normalization pass.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 8b35bb986bd4265df2ec58e6988fc21f04fe64de..bd0ec97e2b4a5be52acff046f81ecb7928ce77d7 100644 (file)
@@ -90,6 +90,30 @@ castprop expr = return expr
 -- Perform this transform everywhere
 castproptop = everywhere ("castprop", castprop)
 
+--------------------------------
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
+--------------------------------
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+  -- Don't extract values that are already simpl
+  local_var <- Trans.lift $ is_local_var val
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr val
+  if (not local_var) && repr
+    then do
+      -- Generate a binder for the expression
+      id <- Trans.lift $ mkBinderFor val "castval"
+      -- Extract the expression
+      change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+    else
+      return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
 --------------------------------
 -- let recursification
 --------------------------------
@@ -157,6 +181,24 @@ letflattop = everywhere ("letflat", letflat)
 letremovetop :: Transform
 letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> Trans.lift $ is_local_var e))
 
+--------------------------------
+-- Unused let binding removal
+--------------------------------
+letremoveunused, letremoveunusedtop :: Transform
+letremoveunused expr@(Let (Rec binds) res) = do
+  -- Filter out all unused binds.
+  let binds' = filter dobind binds
+  -- Only set the changed flag if binds got removed
+  changeif (length binds' /= length binds) (Let (Rec binds') res)
+    where
+      bound_exprs = map snd binds
+      -- For each bind check if the bind is used by res or any of the bound
+      -- expressions
+      dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs)
+-- Leave all other expressions unchanged
+letremoveunused expr = return expr
+letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
+
 --------------------------------
 -- Function inlining
 --------------------------------
@@ -473,7 +515,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::
@@ -504,7 +546,7 @@ normalizeExpr what expr = do
       -- Normalize this expression
       trace ("Transforming " ++ what ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
       expr'' <- dotransforms transforms expr'
-      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr'')) $ return ()
       return expr''
 
 -- | Get the value that is bound to the given binder at top level. Fails when