Add class operation resolution transformation.
authorMatthijs Kooijman <matthijs@stdin.nl>
Mon, 8 Mar 2010 10:35:52 +0000 (11:35 +0100)
committerMatthijs Kooijman <matthijs@stdin.nl>
Mon, 8 Mar 2010 10:39:01 +0000 (11:39 +0100)
This transformation enables user-defined type classes.

cλash/CLasH/Normalize.hs

index 4c04e8557553d4044944749b0b49a85c699e779a..86ade5b21728f16d2422090328147ed642332ce8 100644 (file)
@@ -25,6 +25,7 @@ import qualified Id
 import qualified Var
 import qualified VarSet
 import qualified CoreFVs
+import qualified Class
 import qualified MkCore
 import Outputable ( showSDoc, ppr, nest )
 
@@ -380,6 +381,60 @@ inlinedict expr@(Var f) | Id.isDictId f = do
 inlinedict expr = return expr
 inlinedicttop = everywhere ("inlinedict", inlinedict)
 
+--------------------------------
+-- ClassOp resolution
+--------------------------------
+-- Resolves any class operation to the actual operation whenever
+-- possible. Class methods (as well as parent dictionary selectors) are
+-- special "functions" that take a type and a dictionary and evaluate to
+-- the corresponding method. A dictionary is nothing more than a
+-- special dataconstructor applied to the type the dictionary is for,
+-- each of the superclasses and all of the class method definitions for
+-- that particular type. Since dictionaries all always inlined (top
+-- levels dictionaries are inlined by inlinedict, local dictionaries are
+-- inlined by inlinenonrep), we will eventually have something like:
+--
+--   baz
+--     @ CLasH.HardwareTypes.Bit
+--     (D:Baz @ CLasH.HardwareTypes.Bit bitbaz)
+--
+-- Here, baz is the method selector for the baz method, while
+-- D:Baz is the dictionary constructor for the Baz and bitbaz is the baz
+-- method defined in the Baz Bit instance declaration.
+--
+-- To resolve this, we can look at the ClassOp IdInfo from the baz Id,
+-- which contains the Class it is defined for. From the Class, we can
+-- get a list of all selectors (both parent class selectors as well as
+-- method selectors). Since the arguments to D:Baz (after the type
+-- argument) correspond exactly to this list, we then look up baz in
+-- that list and replace the entire expression by the corresponding 
+-- argument to D:Baz.
+classopresolution, classopresolutiontop :: Transform
+classopresolution expr@(App (App (Var sel) ty) dict) =
+  case Id.isClassOpId_maybe sel of
+    -- Not a class op selector
+    Nothing -> return expr
+    Just cls -> case collectArgs dict of
+      (_, []) -> return expr -- Dict is not an application (e.g., not inlined yet)
+      (dictdc, (ty':selectors)) | tyargs_neq ty ty' -> error $ "Applying class selector to dictionary without matching type?\n" ++ pprString expr
+                                | otherwise ->
+        let selector_ids = Class.classSelIds cls in
+        -- Find the selector used in the class' list of selectors
+        case List.elemIndex sel selector_ids of
+          Nothing -> error $ "Selector not found in class' selector list? This should not happen!\nExpression: " ++ pprString expr ++ "\nClass: " ++ show cls ++ "\nSelectors: " ++ show selector_ids
+          -- Get the corresponding argument from the dictionary
+          Just n -> change (selectors!!n)
+  where
+    -- Compare two type arguments, returning True if they are _not_
+    -- equal
+    tyargs_neq (Type ty1) (Type ty2) = not $ Type.coreEqType ty1 ty2
+    tyargs_neq _ _ = True
+
+-- Leave all other expressions unchanged
+classopresolution expr = return expr
+-- Perform this transform everywhere
+classopresolutiontop = everywhere ("classopresolution", classopresolution)
+
 --------------------------------
 -- Scrutinee simplification
 --------------------------------
@@ -736,7 +791,7 @@ simplrestop expr = do
 
 
 -- What transforms to run?
-transforms = [inlinedicttop, inlinetopleveltop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
+transforms = [inlinedicttop, inlinetopleveltop, classopresolutiontop, argproptop, funextracttop, etatop, betatop, castproptop, letremovesimpletop, letderectop, letremovetop, letsimpltop, letflattop, scrutsimpltop, scrutbndrremovetop, casesimpltop, caseremovetop, inlinenonreptop, appsimpltop, letremoveunusedtop, castsimpltop, lambdasimpltop, simplrestop]
 
 -- | Returns the normalized version of the given function, or an error
 -- if it is not a known global binder.