Merge branch 'master' of http://git.stderr.nl/matthijs/master-project/cλash
[matthijs/master-project/cλash.git] / cλash / CLasH / Normalize.hs
index f0f2de2fccba3187092c4958bc1e3382cc0c0682..31474d59328bd0342fc6244262061eba736b8c80 100644 (file)
@@ -13,7 +13,9 @@ import qualified List
 import qualified "transformers" Control.Monad.Trans as Trans
 import qualified Control.Monad as Monad
 import qualified Control.Monad.Trans.Writer as Writer
+import qualified Data.Accessor.Monad.Trans.State as MonadState
 import qualified Data.Monoid as Monoid
+import qualified Data.Map as Map
 
 -- GHC API
 import CoreSyn
@@ -21,8 +23,10 @@ import qualified CoreUtils
 import qualified Type
 import qualified Id
 import qualified Var
+import qualified Name
 import qualified VarSet
 import qualified CoreFVs
+import qualified Class
 import qualified MkCore
 import Outputable ( showSDoc, ppr, nest )
 
@@ -30,6 +34,7 @@ import Outputable ( showSDoc, ppr, nest )
 import CLasH.Normalize.NormalizeTypes
 import CLasH.Translator.TranslatorTypes
 import CLasH.Normalize.NormalizeTools
+import CLasH.VHDL.Constants (builtinIds)
 import qualified CLasH.Utils as Utils
 import CLasH.Utils.Core.CoreTools
 import CLasH.Utils.Core.BinderTools
@@ -333,38 +338,134 @@ inlinenonreptop = everywhere ("inlinenonrep", inlinebind ((Monad.liftM not) . is
 -- doesn't seem to set all those names as "system names", we apply some
 -- guessing here.
 inlinetoplevel, inlinetopleveltop :: Transform
+-- HACK: Don't inline == and /=. The default (derived) implementation
+-- for /= uses the polymorphic version of ==, which gets a dictionary
+-- for Eq passed in, which contains a reference to itself, resulting in
+-- an infinite loop in transformation. Not inlining == is really a hack,
+-- but for now it keeps things working with the most common symptom of
+-- this problem.
+inlinetoplevel expr@(Var f) | Name.getOccString f `elem` ["==", "/="] = return expr
 -- Any system name is candidate for inlining. Never inline user-defined
 -- functions, to preserve structure.
 inlinetoplevel expr@(Var f) | not $ isUserDefined f = do
-  norm <- isNormalizeable f
-  -- See if this is a top level binding for which we have a body
-  body_maybe <- Trans.lift $ getGlobalBind f
-  if norm && Maybe.isJust body_maybe
-    then do
-      -- Get the normalized version
-      norm <- Trans.lift $ getNormalized f
-      if needsInline norm 
-        then do
-          -- Regenerate all uniques in the to-be-inlined expression
-          norm_uniqued <- Trans.lift $ genUniques norm
-          change norm_uniqued
-        else
-          return expr
-    else
-      -- No body or not normalizeable.
-      return expr
+  body_maybe <- needsInline f
+  case body_maybe of
+    Just body -> do
+        -- Regenerate all uniques in the to-be-inlined expression
+        body_uniqued <- Trans.lift $ genUniques body
+        -- And replace the variable reference with the unique'd body.
+        change body_uniqued
+        -- No need to inline
+    Nothing -> return expr
+
+
 -- Leave all other expressions unchanged
 inlinetoplevel expr = return expr
 inlinetopleveltop = everywhere ("inlinetoplevel", inlinetoplevel)
+  
+-- | Does the given binder need to be inlined? If so, return the body to
+-- be used for inlining.
+needsInline :: CoreBndr -> TransformMonad (Maybe CoreExpr)
+needsInline f = do
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    -- No body available?
+    Nothing -> return Nothing
+    Just body -> case CoreSyn.collectArgs body of
+      -- The body is some (top level) binder applied to 0 or more
+      -- arguments. That should be simple enough to inline.
+      (Var f, args) -> return $ Just body
+      -- Body is more complicated, try normalizing it
+      _ -> do
+        norm_maybe <- Trans.lift $ getNormalized_maybe f
+        case norm_maybe of
+          -- Noth normalizeable
+          Nothing -> return Nothing 
+          Just norm -> case splitNormalized norm of
+            -- The function has just a single binding, so that's simple
+            -- enough to inline.
+            (args, [bind], res) -> return $ Just norm
+            -- More complicated function, don't inline
+            _ -> return Nothing
+            
+--------------------------------
+-- Dictionary inlining
+--------------------------------
+-- Inline all top level dictionaries, so we can use them to resolve
+-- class methods based on the dictionary passed. 
+inlinedict expr@(Var f) | Id.isDictId f = do
+  body_maybe <- Trans.lift $ getGlobalBind f
+  case body_maybe of
+    Nothing -> return expr
+    Just body -> change body
 
-needsInline :: CoreExpr -> Bool
-needsInline expr = case splitNormalized expr of
-  -- Inline any function that only has a single definition, it is probably
-  -- simple enough. This might inline some stuff that it shouldn't though it
-  -- will never inline user-defined functions (inlinetoplevel only tries
-  -- system names) and inlining should never break things.
-  (args, [bind], res) -> True
-  _ -> False
+-- Leave all other expressions unchanged
+inlinedict expr = return expr
+inlinedicttop = everywhere ("inlinedict", inlinedict)
+
+--------------------------------
+-- ClassOp resolution
+--------------------------------
+-- Resolves any class operation to the actual operation whenever
+-- possible. Class methods (as well as parent dictionary selectors) are
+-- special "functions" that take a type and a dictionary and evaluate to
+-- the corresponding method. A dictionary is nothing more than a
+-- special dataconstructor applied to the type the dictionary is for,
+-- each of the superclasses and all of the class method definitions for
+-- that particular type. Since dictionaries all always inlined (top
+-- levels dictionaries are inlined by inlinedict, local dictionaries are
+-- inlined by inlinenonrep), we will eventually have something like:
+--
+--   baz
+--     @ CLasH.HardwareTypes.Bit
+--     (D:Baz @ CLasH.HardwareTypes.Bit bitbaz)
+--
+-- Here, baz is the method selector for the baz method, while
+-- D:Baz is the dictionary constructor for the Baz and bitbaz is the baz
+-- method defined in the Baz Bit instance declaration.
+--
+-- To resolve this, we can look at the ClassOp IdInfo from the baz Id,
+-- which contains the Class it is defined for. From the Class, we can
+-- get a list of all selectors (both parent class selectors as well as
+-- method selectors). Since the arguments to D:Baz (after the type
+-- argument) correspond exactly to this list, we then look up baz in
+-- that list and replace the entire expression by the corresponding 
+-- argument to D:Baz.
+--
+-- We don't resolve methods that have a builtin translation (such as
+-- ==), since the actual implementation is not always (easily)
+-- translateable. For example, when deriving ==, GHC generates code
+-- using $con2tag functions to translate a datacon to an int and compare
+-- that with GHC.Prim.==# . Better to avoid that for now.
+classopresolution, classopresolutiontop :: Transform
+classopresolution expr@(App (App (Var sel) ty) dict) | not is_builtin =
+  case Id.isClassOpId_maybe sel of
+    -- Not a class op selector
+    Nothing -> return expr
+    Just cls -> case collectArgs dict of
+      (_, []) -> return expr -- Dict is not an application (e.g., not inlined yet)
+      (Var dictdc, (ty':selectors)) | not (Maybe.isJust (Id.isDataConId_maybe dictdc)) -> return expr -- Dictionary is not a datacon yet (but e.g., a top level binder)
+                                | tyargs_neq ty ty' -> error $ "Normalize.classopresolution: Applying class selector to dictionary without matching type?\n" ++ pprString expr
+                                | otherwise ->
+        let selector_ids = Class.classSelIds cls in
+        -- Find the selector used in the class' list of selectors
+        case List.elemIndex sel selector_ids of
+          Nothing -> error $ "Normalize.classopresolution: Selector not found in class' selector list? This should not happen!\nExpression: " ++ pprString expr ++ "\nClass: " ++ show cls ++ "\nSelectors: " ++ show selector_ids
+          -- Get the corresponding argument from the dictionary
+          Just n -> change (selectors!!n)
+      (_, _) -> return expr -- Not applying a variable? Don't touch
+  where
+    -- Compare two type arguments, returning True if they are _not_
+    -- equal
+    tyargs_neq (Type ty1) (Type ty2) = not $ Type.coreEqType ty1 ty2
+    tyargs_neq _ _ = True
+    -- Is this a builtin function / method?
+    is_builtin = elem (Name.getOccString sel) builtinIds
+
+-- Leave all other expressions unchanged
+classopresolution expr = return expr
+-- Perform this transform everywhere
+classopresolutiontop = everywhere ("classopresolution", classopresolution)
 
 --------------------------------
 -- Scrutinee simplification
@@ -428,17 +529,20 @@ casesimpl expr@(Case scrut b ty [(con, bndrs, Var x)]) = return expr
 -- is bound to a new simple selector case statement and for each complex
 -- expression. We do this only for representable types, to prevent loops with
 -- inlinenonrep.
-casesimpl expr@(Case scrut b ty alts) = do
+casesimpl expr@(Case scrut bndr ty alts) | not bndr_used = do
   (bindingss, alts') <- (Monad.liftM unzip) $ mapM doalt alts
   let bindings = concat bindingss
   -- Replace the case with a let with bindings and a case
-  let newlet = mkNonRecLets bindings (Case scrut b ty alts')
+  let newlet = mkNonRecLets bindings (Case scrut bndr ty alts')
   -- If there are no non-wild binders, or this case is already a simple
   -- selector (i.e., a single alt with exactly one binding), already a simple
   -- selector altan no bindings (i.e., no wild binders in the original case),
   -- don't change anything, otherwise, replace the case.
   if null bindings then return expr else change newlet 
   where
+  -- Check if the scrutinee binder is used
+  is_used (_, _, expr) = expr_uses_binders [bndr] expr
+  bndr_used = or $ map is_used alts
   -- Generate a single wild binder, since they are all the same
   wild = MkCore.mkWildBinder
   -- Wilden the binders of one alt, producing a list of bindings as a
@@ -520,7 +624,7 @@ caseremove, caseremovetop :: Transform
 -- Replace a useless case by the value of its single alternative
 caseremove (Case scrut b ty [(con, bndrs, expr)]) | not usesvars = change expr
     -- Find if any of the binders are used by expr
-    where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` bndrs))) expr
+    where usesvars = (not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars (`elem` b:bndrs))) expr
 -- Leave all other expressions unchanged
 caseremove expr = return expr
 -- Perform this transform everywhere
@@ -576,6 +680,12 @@ argprop expr@(App _ _) | is_var fexpr = do
           let newbody = MkCore.mkCoreLams newparams (MkCore.mkCoreApps body oldargs)
           -- Create a new function with the same name but a new body
           newf <- Trans.lift $ mkFunction f newbody
+
+          Trans.lift $ MonadState.modify tsInitStates (\ismap ->
+            let init_state_maybe = Map.lookup f ismap in
+            case init_state_maybe of
+              Nothing -> ismap
+              Just init_state -> Map.insert newf init_state ismap)
           -- Replace the original application with one of the new function to the
           -- new arguments.
           change $ MkCore.mkCoreApps (Var newf) newargs
@@ -716,22 +826,43 @@ simplrestop expr = do
 
 
 -- What transforms to run?
-transforms = [inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
--- | Returns the normalized version of the given function.
+-- | Returns the normalized version of the given function, or an error
+-- if it is not a known global binder.
 getNormalized ::
   CoreBndr -- ^ The function to get
   -> TranslatorSession CoreExpr -- The normalized function body
-
-getNormalized bndr = Utils.makeCached bndr tsNormalized $
-  if is_poly (Var bndr)
-    then
-      -- This should really only happen at the top level... TODO: Give
-      -- a different error if this happens down in the recursion.
-      error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
-    else do
-      expr <- getBinding bndr
-      normalizeExpr (show bndr) expr
+getNormalized bndr = do
+  norm <- getNormalized_maybe bndr
+  return $ Maybe.fromMaybe
+    (error $ "Normalize.getNormalized: Unknown or non-representable function requested: " ++ show bndr)
+    norm
+
+-- | Returns the normalized version of the given function, or Nothing
+-- when the binder is not a known global binder or is not normalizeable.
+getNormalized_maybe ::
+  CoreBndr -- ^ The function to get
+  -> TranslatorSession (Maybe CoreExpr) -- The normalized function body
+
+getNormalized_maybe bndr = do
+    expr_maybe <- getGlobalBind bndr
+    normalizeable <- isNormalizeable' bndr
+    if not normalizeable || Maybe.isNothing expr_maybe
+      then
+        -- Binder not normalizeable or not found
+        return Nothing
+      else if is_poly (Var bndr)
+        then
+          -- This should really only happen at the top level... TODO: Give
+          -- a different error if this happens down in the recursion.
+          error $ "\nNormalize.normalizeBind: Function " ++ show bndr ++ " is polymorphic, can't normalize"
+        else do
+          -- Binder found and is monomorphic. Normalize the expression
+          -- and cache the result.
+          normalized <- Utils.makeCached bndr tsNormalized $ 
+            normalizeExpr (show bndr) (Maybe.fromJust expr_maybe)
+          return (Just normalized)
 
 -- | Normalize an expression
 normalizeExpr ::
@@ -747,17 +878,6 @@ normalizeExpr what expr = do
       trace ("\n" ++ what ++ " after normalization:\n\n" ++ showSDoc ( ppr expr')) $ return ()
       return expr'
 
--- | Get the value that is bound to the given binder at top level. Fails when
---   there is no such binding.
-getBinding ::
-  CoreBndr -- ^ The binder to get the expression for
-  -> TranslatorSession CoreExpr -- ^ The value bound to the binder
-
-getBinding bndr = Utils.makeCached bndr tsBindings $
-  -- If the binding isn't in the "cache" (bindings map), then we can't create
-  -- it out of thin air, so return an error.
-  error $ "Normalize.getBinding: Unknown function requested: " ++ show bndr
-
 -- | Split a normalized expression into the argument binders, top level
 --   bindings and the result binder.
 splitNormalized ::