Make vhdl generation and normalization lazy.
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index 8ec195b0ef936aadd89449571988da9e3c4f56e0..7571a6f3b0fbf21b33673fe59476bd217ab336ed 100644 (file)
@@ -4,7 +4,7 @@
 -- top level function "normalize", and defines the actual transformation passes that
 -- are performed.
 --
-module CLasH.Normalize (normalizeModule) where
+module CLasH.Normalize (getNormalized) where
 
 -- Standard modules
 import Debug.Trace
@@ -34,8 +34,10 @@ import Outputable ( showSDoc, ppr, nest )
 
 -- Local imports
 import CLasH.Normalize.NormalizeTypes
+import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
 import CLasH.VHDL.VHDLTypes
+import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Pretty
 
@@ -472,78 +474,36 @@ funextracttop = everywhere ("funextract", funextract)
 -- What transforms to run?
 transforms = [argproptop, funextracttop, etatop, betatop, castproptop, letremovetop, letrectop, letsimpltop, letflattop, scrutsimpltop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop]
 
--- Turns the given bind into VHDL
-normalizeModule ::
-  HscTypes.HscEnv
-  -> UniqSupply.UniqSupply -- ^ A UniqSupply we can use
-  -> [(CoreBndr, CoreExpr)]  -- ^ All bindings we know (i.e., in the current module)
-  -> [CoreExpr]
-  -> [CoreBndr]  -- ^ The bindings to generate VHDL for (i.e., the top level bindings)
-  -> [Bool] -- ^ For each of the bindings to generate VHDL for, if it is stateful
-  -> ([(CoreBndr, CoreExpr)], [(CoreBndr, CoreExpr)], TypeState) -- ^ The resulting VHDL
+-- | Returns the normalized version of the given function.
+getNormalized ::
+  CoreBndr -- ^ The function to get
+  -> TranslatorSession CoreExpr -- The normalized function body
 
-normalizeModule env uniqsupply bindings testexprs generate_for statefuls = runTransformSession env uniqsupply $ do
-  testbinds <- mapM (\x -> do { v <- mkBinderFor' x "test" ; return (v,x) } ) testexprs
-  let testbinders = (map fst testbinds)
-  -- Put all the bindings in this module in the tsBindings map
-  putA tsBindings (Map.fromList (bindings ++ testbinds))
-  -- (Recursively) normalize each of the requested bindings
-  mapM normalizeBind (generate_for ++ testbinders)
-  -- Get all initial bindings and the ones we produced
-  bindings_map <- getA tsBindings
-  let bindings = Map.assocs bindings_map
-  normalized_binders' <- getA tsNormalized
-  let normalized_binders = VarSet.delVarSetList normalized_binders' testbinders
-  let ret_testbinds = zip testbinders (Maybe.catMaybes $ map (\x -> lookup x bindings) testbinders)
-  let ret_binds = filter ((`VarSet.elemVarSet` normalized_binders) . fst) bindings
-  typestate <- getA tsType
-  -- But return only the normalized bindings
-  return $ (ret_binds, ret_testbinds, typestate)
+getNormalized bndr = Utils.makeCached bndr tsNormalized $ do
+  if is_poly (Var bndr)
+    then
+      -- This should really only happen at the top level... TODO: Give
+      -- a different error if this happens down in the recursion.
+      error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+    else do
+      expr <- getBinding bndr
+      -- Introduce an empty Let at the top level, so there will always be
+      -- a let in the expression (none of the transformations will remove
+      -- the last let).
+      let expr' = Let (Rec []) expr
+      -- Normalize this expression
+      trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
+      expr'' <- dotransforms transforms expr'
+      trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
+      return expr''
 
-normalizeBind :: CoreBndr -> TransformSession ()
-normalizeBind bndr =
-  -- Don't normalize global variables, these should be either builtin
-  -- functions or data constructors.
-  Monad.when (Var.isLocalId bndr) $ do
-    -- Skip binders that have a polymorphic type, since it's impossible to
-    -- create polymorphic hardware.
-    if is_poly (Var bndr)
-      then
-        -- This should really only happen at the top level... TODO: Give
-        -- a different error if this happens down in the recursion.
-        error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
-      else do
-        normalized_funcs <- getA tsNormalized
-        -- See if this function was normalized already
-        if VarSet.elemVarSet bndr normalized_funcs
-          then
-            -- Yup, don't do it again
-            return ()
-          else do
-            -- Nope, note that it has been and do it.
-            modA tsNormalized (flip VarSet.extendVarSet bndr)
-            expr_maybe <- getGlobalBind bndr
-            case expr_maybe of 
-              Just expr -> do
-                -- Introduce an empty Let at the top level, so there will always be
-                -- a let in the expression (none of the transformations will remove
-                -- the last let).
-                let expr' = Let (Rec []) expr
-                -- Normalize this expression
-                trace ("Transforming " ++ (show bndr) ++ "\nBefore:\n\n" ++ showSDoc ( ppr expr' ) ++ "\n") $ return ()
-                expr' <- dotransforms transforms expr'
-                trace ("\nAfter:\n\n" ++ showSDoc ( ppr expr')) $ return ()
-                -- And store the normalized version in the session
-                modA tsBindings (Map.insert bndr expr')
-                -- Find all vars used with a function type. All of these should be global
-                -- binders (i.e., functions used), since any local binders with a function
-                -- type should have been inlined already.
-                bndrs <- getGlobalBinders
-                let used_funcs_set = CoreFVs.exprSomeFreeVars (\v -> not (Id.isDictId v) && v `elem` bndrs) expr'
-                let used_funcs = VarSet.varSetElems used_funcs_set
-                -- Process each of the used functions recursively
-                mapM normalizeBind used_funcs
-                return ()
-              -- We don't have a value for this binder. This really shouldn't
-              -- happen for local id's...
-              Nothing -> error $ "\nNormalize.normalizeBind: No value found for binder " ++ pprString bndr ++ "? This should not happen!"
+-- | Get the value that is bound to the given binder at top level. Fails when
+--   there is no such binding.
+getBinding ::
+  CoreBndr -- ^ The binder to get the expression for
+  -> TranslatorSession CoreExpr -- ^ The value bound to the binder
+
+getBinding bndr = Utils.makeCached bndr tsBindings $ do
+  -- If the binding isn't in the "cache" (bindings map), then we can't create
+  -- it out of thin air, so return an error.
+  error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr