Merge git://github.com/darchon/clash into cλash
[matthijs/master-project/cλash.git] / Normalize.hs
index cceefc0cf7baf76c4135da6ef12cc7f9c5922988..93369213e5260f4a3661e1f67769713d00ec4c76 100644 (file)
@@ -43,7 +43,6 @@ import Pretty
 -- η abstraction
 --------------------------------
 eta, etatop :: Transform
-eta expr | is_fun expr && not (is_lam expr) = do
 eta expr | is_fun expr && not (is_lam expr) = do
   let arg_ty = (fst . Type.splitFunTy . CoreUtils.exprType) expr
   id <- mkInternalVar "param" arg_ty
@@ -70,6 +69,20 @@ beta expr = return expr
 -- Perform this transform everywhere
 betatop = everywhere ("beta", beta)
 
+--------------------------------
+-- Cast propagation
+--------------------------------
+-- Try to move casts as much downward as possible.
+castprop, castproptop :: Transform
+castprop (Cast (Let binds expr) ty) = change $ Let binds (Cast expr ty)
+castprop expr@(Cast (Case scrut b _ alts) ty) = change (Case scrut b ty alts')
+  where
+    alts' = map (\(con, bndrs, expr) -> (con, bndrs, (Cast expr ty))) alts
+-- Leave all other expressions unchanged
+castprop expr = return expr
+-- Perform this transform everywhere
+castproptop = everywhere ("castprop", castprop)
+
 --------------------------------
 -- let recursification
 --------------------------------
@@ -127,7 +140,7 @@ letflattop = everywhere ("letflat", letflat)
 --------------------------------
 -- Remove a = b bindings from let expressions everywhere
 letremovetop :: Transform
-letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) -> True; otherwise -> False))
+letremovetop = everywhere ("letremove", inlinebind (\(b, e) -> case e of (Var v) | not $ Id.isDataConWorkId v -> True; otherwise -> False))
 
 --------------------------------
 -- Function inlining
@@ -200,6 +213,7 @@ casewild expr@(Case scrut b ty alts) = do
       -- and binds that to b.
       mkextracts :: CoreBndr -> Int -> TransformMonad (Maybe (CoreBndr, CoreExpr))
       mkextracts b i =
+        -- TODO: Use free variables instead of is_wild. is_wild is a hack.
         if is_wild b || Type.isFunTy (Id.idType b) 
           -- Don't create extra bindings for binders that are already wild, or
           -- for binders that bind function types (to prevent loops with
@@ -271,12 +285,12 @@ caseremove expr = return expr
 caseremovetop = everywhere ("caseremove", caseremove)
 
 --------------------------------
--- Application simplification
+-- Argument extraction
 --------------------------------
--- Make sure that all arguments in an application are simple variables.
+-- Make sure that all arguments of a representable type are simple variables.
 appsimpl, appsimpltop :: Transform
--- Don't simplify arguments that are already simple
-appsimpl expr@(App f (Var _)) = return expr
+-- Don't simplify arguments that are already simple.
+appsimpl expr@(App f (Var v)) = return expr
 -- Simplify all non-applicable (to prevent loops with inlinefun) arguments,
 -- except for type arguments (since a let can't bind type vars, only a lambda
 -- can). Do this by introducing a new Let that binds the argument and passing
@@ -397,8 +411,56 @@ funprop expr = return expr
 -- Perform this transform everywhere
 funproptop = everywhere ("funprop", funprop)
 
+--------------------------------
+-- Function-typed argument extraction
+--------------------------------
+-- This transform takes any function-typed argument that cannot be propagated
+-- (because the function that is applied to it is a builtin function), and
+-- puts it in a brand new top level binder. This allows us to for example
+-- apply map to a lambda expression This will not conflict with inlinefun,
+-- since that only inlines local let bindings, not top level bindings.
+funextract, funextracttop :: Transform
+funextract expr@(App _ _) | is_var fexpr = do
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    -- We don't have a function body for f, so we can perform this transform.
+    Nothing -> do
+      -- Find the new arguments
+      args' <- mapM doarg args
+      -- And update the arguments. We use return instead of changed, so the
+      -- changed flag doesn't get set if none of the args got changed.
+      return $ MkCore.mkCoreApps fexpr args'
+    -- We have a function body for f, leave this application to funprop
+    Just _ -> return expr
+  where
+    -- Find the function called and the arguments
+    (fexpr, args) = collectArgs expr
+    Var f = fexpr
+    -- Change any arguments that have a function type, but are not simple yet
+    -- (ie, a variable or application). This means to create a new function
+    -- for map (\f -> ...) b, but not for map (foo a) b.
+    --
+    -- We could use is_applicable here instead of is_fun, but I think
+    -- arguments to functions could only have forall typing when existential
+    -- typing is enabled. Not sure, though.
+    doarg arg | not (is_simple arg) && is_fun arg = do
+      -- Create a new top level binding that binds the argument. Its body will
+      -- be extended with lambda expressions, to take any free variables used
+      -- by the argument expression.
+      let free_vars = VarSet.varSetElems $ CoreFVs.exprFreeVars arg
+      let body = MkCore.mkCoreLams free_vars arg
+      id <- mkBinderFor body "fun"
+      Trans.lift $ addGlobalBind id body
+      -- Replace the argument with a reference to the new function, applied to
+      -- all vars it uses.
+      change $ MkCore.mkCoreApps (Var id) (map Var free_vars)
+    -- Leave all other arguments untouched
+    doarg arg = return arg
 
--- TODO: introduce top level let if needed?
+-- Leave all other expressions unchanged
+funextract expr = return expr
+-- Perform this transform everywhere
+funextracttop = everywhere ("funextract", funextract)
 
 --------------------------------
 -- End of transformations
@@ -408,7 +470,7 @@ funproptop = everywhere ("funprop", funprop)
 
 
 -- What transforms to run?
-transforms = [typeproptop, funproptop, etatop, betatop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
+transforms = [typeproptop, funproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, casewildtop, scrutsimpltop, casevalsimpltop, caseremovetop, inlinefuntop, appsimpltop]
 
 -- Turns the given bind into VHDL
 normalizeModule :: 
@@ -441,7 +503,7 @@ normalizeBind bndr =
       then
         -- This should really only happen at the top level... TODO: Give
         -- a different error if this happens down in the recursion.
-        error $ "Function " ++ show bndr ++ " is polymorphic, can't normalize"
+        error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
       else do
         normalized_funcs <- getA tsNormalized
         -- See if this function was normalized already
@@ -475,4 +537,4 @@ normalizeBind bndr =
                 return ()
               -- We don't have a value for this binder. This really shouldn't
               -- happen for local id's...
-              Nothing -> error $ "No value found for binder " ++ pprString bndr ++ "? This should not happen!"
+              Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"