Add cast simplification normalization pass.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index b2b4bd86f0080693d29f70183ef469083b15bdfc..bd0ec97e2b4a5be52acff046f81ecb7928ce77d7 100644 (file)
@@ -90,6 +90,30 @@ castprop expr = return expr
 -- Perform this transform everywhere
 castproptop = everywhere ("castprop", castprop)
 
+--------------------------------
+-- Cast simplification. Mostly useful for state packing and unpacking, but
+-- perhaps for others as well.
+--------------------------------
+castsimpl, castsimpltop :: Transform
+castsimpl expr@(Cast val ty) = do
+  -- Don't extract values that are already simpl
+  local_var <- Trans.lift $ is_local_var val
+  -- Don't extract values that are not representable, to prevent loops with
+  -- inlinenonrep
+  repr <- isRepr val
+  if (not local_var) && repr
+    then do
+      -- Generate a binder for the expression
+      id <- Trans.lift $ mkBinderFor val "castval"
+      -- Extract the expression
+      change $ Let (Rec [(id, val)]) (Cast (Var id) ty)
+    else
+      return expr
+-- Leave all other expressions unchanged
+castsimpl expr = return expr
+-- Perform this transform everywhere
+castsimpltop = everywhere ("castsimpl", castsimpl)
+
 --------------------------------
 -- let recursification
 --------------------------------
@@ -170,7 +194,7 @@ letremoveunused expr@(Let (Rec binds) res) = do
       bound_exprs = map snd binds
       -- For each bind check if the bind is used by res or any of the bound
       -- expressions
-      dobind (bndr, _) = not $ any (expr_uses_binders [bndr]) (res:bound_exprs)
+      dobind (bndr, _) = any (expr_uses_binders [bndr]) (res:bound_exprs)
 -- Leave all other expressions unchanged
 letremoveunused expr = return expr
 letremoveunusedtop = everywhere ("letremoveunused", letremoveunused)
@@ -491,7 +515,7 @@ funextracttop = everywhere ("funextract", funextract)
 
 
 -- What transforms to run?
-transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop]
+transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop]
 
 -- | Returns the normalized version of the given function.
 getNormalized ::