Remove some commented out code.
[matthijs/master-project/cλash.git] / cλash / CLasH / Utils / Core / CoreTools.hs
index d8f4289d28b25253544126ba0424f90f6fc842de..2bb688bb7f0c023a1d9b7986a97eb581f22b808c 100644 (file)
@@ -7,7 +7,9 @@ module CLasH.Utils.Core.CoreTools where
 
 --Standard modules
 import qualified Maybe
 
 --Standard modules
 import qualified Maybe
-import System.IO.Unsafe
+import qualified System.IO.Unsafe
+import qualified Data.Map as Map
+import qualified Data.Accessor.Monad.Trans.State as MonadState
 
 -- GHC API
 import qualified GHC
 
 -- GHC API
 import qualified GHC
@@ -15,39 +17,71 @@ import qualified Type
 import qualified TcType
 import qualified HsExpr
 import qualified HsTypes
 import qualified TcType
 import qualified HsExpr
 import qualified HsTypes
-import qualified HsBinds
 import qualified HscTypes
 import qualified HscTypes
-import qualified RdrName
 import qualified Name
 import qualified Name
-import qualified OccName
-import qualified Type
 import qualified Id
 import qualified TyCon
 import qualified DataCon
 import qualified TysWiredIn
 import qualified Id
 import qualified TyCon
 import qualified DataCon
 import qualified TysWiredIn
-import qualified Bag
 import qualified DynFlags
 import qualified SrcLoc
 import qualified CoreSyn
 import qualified Var
 import qualified IdInfo
 import qualified VarSet
 import qualified DynFlags
 import qualified SrcLoc
 import qualified CoreSyn
 import qualified Var
 import qualified IdInfo
 import qualified VarSet
-import qualified Unique
 import qualified CoreUtils
 import qualified CoreFVs
 import qualified Literal
 import qualified CoreUtils
 import qualified CoreFVs
 import qualified Literal
+import qualified MkCore
+import qualified VarEnv
 
 -- Local imports
 import CLasH.Translator.TranslatorTypes
 import CLasH.Utils.GhcTools
 
 -- Local imports
 import CLasH.Translator.TranslatorTypes
 import CLasH.Utils.GhcTools
+import CLasH.Utils.Core.BinderTools
 import CLasH.Utils.HsTools
 import CLasH.Utils.Pretty
 import CLasH.Utils.HsTools
 import CLasH.Utils.Pretty
+import CLasH.Utils
+import qualified CLasH.Utils.Core.BinderTools as BinderTools
 
 -- | A single binding, used as a shortcut to simplify type signatures.
 type Binding = (CoreSyn.CoreBndr, CoreSyn.CoreExpr)
 
 -- | Evaluate a core Type representing type level int from the tfp
 
 -- | A single binding, used as a shortcut to simplify type signatures.
 type Binding = (CoreSyn.CoreBndr, CoreSyn.CoreExpr)
 
 -- | Evaluate a core Type representing type level int from the tfp
--- library to a real int.
+-- library to a real int. Checks if the type really is a Dec type and
+-- caches the results.
+tfp_to_int :: Type.Type -> TypeSession Int
+tfp_to_int ty = do
+  hscenv <- MonadState.get tsHscEnv
+  let norm_ty = normalize_tfp_int hscenv ty
+  case Type.splitTyConApp_maybe norm_ty of
+    Just (tycon, args) -> do
+      let name = Name.getOccString (TyCon.tyConName tycon)
+      case name of
+        "Dec" ->
+          tfp_to_int' ty
+        otherwise -> do
+          return $ error ("Callin tfp_to_int on non-dec:" ++ (show ty))
+    Nothing -> return $ error ("Callin tfp_to_int on non-dec:" ++ (show ty))
+
+-- | Evaluate a core Type representing type level int from the tfp
+-- library to a real int. Caches the results. Do not use directly, use
+-- tfp_to_int instead.
+tfp_to_int' :: Type.Type -> TypeSession Int
+tfp_to_int' ty = do
+  lens <- MonadState.get tsTfpInts
+  hscenv <- MonadState.get tsHscEnv
+  let norm_ty = normalize_tfp_int hscenv ty
+  let existing_len = Map.lookup (OrdType norm_ty) lens
+  case existing_len of
+    Just len -> return len
+    Nothing -> do
+      let new_len = eval_tfp_int hscenv ty
+      MonadState.modify tsTfpInts (Map.insert (OrdType norm_ty) (new_len))
+      return new_len
+      
+-- | Evaluate a core Type representing type level int from the tfp
+-- library to a real int. Do not use directly, use tfp_to_int instead.
 eval_tfp_int :: HscTypes.HscEnv -> Type.Type -> Int
 eval_tfp_int env ty =
   unsafeRunGhc libdir $ do
 eval_tfp_int :: HscTypes.HscEnv -> Type.Type -> Int
 eval_tfp_int env ty =
   unsafeRunGhc libdir $ do
@@ -67,16 +101,11 @@ eval_tfp_int env ty =
     libdir = DynFlags.topDir dynflags
     dynflags = HscTypes.hsc_dflags env
 
     libdir = DynFlags.topDir dynflags
     dynflags = HscTypes.hsc_dflags env
 
-normalise_tfp_int :: HscTypes.HscEnv -> Type.Type -> Type.Type
-normalise_tfp_int env ty =
-   unsafePerformIO $ do
-     nty <- normaliseType env ty
-     return nty
+normalize_tfp_int :: HscTypes.HscEnv -> Type.Type -> Type.Type
+normalize_tfp_int env ty =
+   System.IO.Unsafe.unsafePerformIO $
+     normalizeType env ty
 
 
--- | Get the width of a SizedWord type
--- sized_word_len :: HscTypes.HscEnv -> Type.Type -> Int
--- sized_word_len env ty = eval_tfp_int env (sized_word_len_ty ty)
-    
 sized_word_len_ty :: Type.Type -> Type.Type
 sized_word_len_ty ty = len
   where
 sized_word_len_ty :: Type.Type -> Type.Type
 sized_word_len_ty ty = len
   where
@@ -85,10 +114,6 @@ sized_word_len_ty ty = len
       Nothing -> error $ "\nCoreTools.sized_word_len_ty: Not a sized word type: " ++ (pprString ty)
     [len]         = args
 
       Nothing -> error $ "\nCoreTools.sized_word_len_ty: Not a sized word type: " ++ (pprString ty)
     [len]         = args
 
--- | Get the width of a SizedInt type
--- sized_int_len :: HscTypes.HscEnv -> Type.Type -> Int
--- sized_int_len env ty = eval_tfp_int env (sized_int_len_ty ty)
-
 sized_int_len_ty :: Type.Type -> Type.Type
 sized_int_len_ty ty = len
   where
 sized_int_len_ty :: Type.Type -> Type.Type
 sized_int_len_ty ty = len
   where
@@ -97,10 +122,6 @@ sized_int_len_ty ty = len
       Nothing -> error $ "\nCoreTools.sized_int_len_ty: Not a sized int type: " ++ (pprString ty)
     [len]         = args
     
       Nothing -> error $ "\nCoreTools.sized_int_len_ty: Not a sized int type: " ++ (pprString ty)
     [len]         = args
     
--- | Get the upperbound of a RangedWord type
--- ranged_word_bound :: HscTypes.HscEnv -> Type.Type -> Int
--- ranged_word_bound env ty = eval_tfp_int env (ranged_word_bound_ty ty)
-    
 ranged_word_bound_ty :: Type.Type -> Type.Type
 ranged_word_bound_ty ty = len
   where
 ranged_word_bound_ty :: Type.Type -> Type.Type
 ranged_word_bound_ty ty = len
   where
@@ -109,26 +130,6 @@ ranged_word_bound_ty ty = len
       Nothing -> error $ "\nCoreTools.ranged_word_bound_ty: Not a sized word type: " ++ (pprString ty)
     [len]         = args
 
       Nothing -> error $ "\nCoreTools.ranged_word_bound_ty: Not a sized word type: " ++ (pprString ty)
     [len]         = args
 
--- | Evaluate a core Type representing type level int from the TypeLevel
--- library to a real int.
--- eval_type_level_int :: Type.Type -> Int
--- eval_type_level_int ty =
---   unsafeRunGhc $ do
---     -- Automatically import modules for any fully qualified identifiers
---     setDynFlag DynFlags.Opt_ImplicitImportQualified
--- 
---     let to_int_name = mkRdrName "Data.TypeLevel.Num.Sets" "toInt"
---     let to_int = SrcLoc.noLoc $ HsExpr.HsVar to_int_name
---     let undef = hsTypedUndef $ coreToHsType ty
---     let app = HsExpr.HsApp (to_int) (undef)
--- 
---     core <- toCore [] app
---     execCore core 
-
--- | Get the length of a FSVec type
--- tfvec_len :: HscTypes.HscEnv -> Type.Type -> Int
--- tfvec_len env ty = eval_tfp_int env (tfvec_len_ty ty)
-
 tfvec_len_ty :: Type.Type -> Type.Type
 tfvec_len_ty ty = len
   where  
 tfvec_len_ty :: Type.Type -> Type.Type
 tfvec_len_ty ty = len
   where  
@@ -151,6 +152,11 @@ is_lam :: CoreSyn.CoreExpr -> Bool
 is_lam (CoreSyn.Lam _ _) = True
 is_lam _ = False
 
 is_lam (CoreSyn.Lam _ _) = True
 is_lam _ = False
 
+-- Is the given core expression a let expression?
+is_let :: CoreSyn.CoreExpr -> Bool
+is_let (CoreSyn.Let _ _) = True
+is_let _ = False
+
 -- Is the given core expression of a function type?
 is_fun :: CoreSyn.CoreExpr -> Bool
 -- Treat Type arguments differently, because exprType is not defined for them.
 -- Is the given core expression of a function type?
 is_fun :: CoreSyn.CoreExpr -> Bool
 -- Treat Type arguments differently, because exprType is not defined for them.
@@ -189,6 +195,10 @@ is_simple _ = False
 has_free_tyvars :: CoreSyn.CoreExpr -> Bool
 has_free_tyvars = not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars Var.isTyVar)
 
 has_free_tyvars :: CoreSyn.CoreExpr -> Bool
 has_free_tyvars = not . VarSet.isEmptyVarSet . (CoreFVs.exprSomeFreeVars Var.isTyVar)
 
+-- Does the given type have any free type vars?
+ty_has_free_tyvars :: Type.Type -> Bool
+ty_has_free_tyvars = not . VarSet.isEmptyVarSet . Type.tyVarsOfType
+
 -- Does the given CoreExpr have any free local vars?
 has_free_vars :: CoreSyn.CoreExpr -> Bool
 has_free_vars = not . VarSet.isEmptyVarSet . CoreFVs.exprFreeVars
 -- Does the given CoreExpr have any free local vars?
 has_free_vars :: CoreSyn.CoreExpr -> Bool
 has_free_vars = not . VarSet.isEmptyVarSet . CoreFVs.exprFreeVars
@@ -220,13 +230,27 @@ get_val_args ty args = drop n args
     -- arguments, to get at the value arguments.
     n = length tyvars + length predtypes
 
     -- arguments, to get at the value arguments.
     n = length tyvars + length predtypes
 
-getLiterals :: CoreSyn.CoreExpr -> [CoreSyn.CoreExpr]
-getLiterals app@(CoreSyn.App _ _) = literals
-  where
-    (CoreSyn.Var f, args) = CoreSyn.collectArgs app
-    literals = filter (is_lit) args
-
-getLiterals lit@(CoreSyn.Lit _) = [lit]
+-- Finds out what literal Integer this expression represents.
+getIntegerLiteral :: CoreSyn.CoreExpr -> TranslatorSession Integer
+getIntegerLiteral expr =
+  case CoreSyn.collectArgs expr of
+    (CoreSyn.Var f, [CoreSyn.Lit (Literal.MachInt integer)]) 
+      | getFullString f == "GHC.Integer.smallInteger" -> return integer
+    (CoreSyn.Var f, [CoreSyn.Lit (Literal.MachInt64 integer)]) 
+      | getFullString f == "GHC.Integer.int64ToInteger" -> return integer
+    (CoreSyn.Var f, [CoreSyn.Lit (Literal.MachWord integer)]) 
+      | getFullString f == "GHC.Integer.wordToInteger" -> return integer
+    (CoreSyn.Var f, [CoreSyn.Lit (Literal.MachWord64 integer)]) 
+      | getFullString f == "GHC.Integer.word64ToInteger" -> return integer
+    -- fromIntegerT returns the integer corresponding to the type of its
+    -- (third) argument. Since it is polymorphic, the type of that
+    -- argument is passed as the first argument, so we can just use that
+    -- one.
+    (CoreSyn.Var f, [CoreSyn.Type dec_ty, dec_dict, CoreSyn.Type num_ty, num_dict, arg]) 
+      | getFullString f == "Types.Data.Num.Ops.fromIntegerT" -> do
+          int <- MonadState.lift tsType $ tfp_to_int dec_ty
+          return $ toInteger int
+    _ -> error $ "CoreTools.getIntegerLiteral: Unsupported Integer literal: " ++ pprString expr
 
 reduceCoreListToHsList :: 
   [HscTypes.CoreModule] -- ^ The modules where parts of the list are hidden
 
 reduceCoreListToHsList :: 
   [HscTypes.CoreModule] -- ^ The modules where parts of the list are hidden
@@ -261,7 +285,7 @@ reduceCoreListToHsList _ _ = return []
 
 -- Is the given var the State data constructor?
 isStateCon :: Var.Var -> Bool
 
 -- Is the given var the State data constructor?
 isStateCon :: Var.Var -> Bool
-isStateCon var = do
+isStateCon var =
   -- See if it is a DataConWrapId (not DataConWorkId, since State is a
   -- newtype).
   case Id.idDetails var of
   -- See if it is a DataConWrapId (not DataConWorkId, since State is a
   -- newtype).
   case Id.idDetails var of
@@ -298,6 +322,29 @@ hasStateType expr = case getType expr of
   Just ty -> isStateType ty
 
 
   Just ty -> isStateType ty
 
 
+-- | Flattens nested lets into a single list of bindings. The expression
+--   passed does not have to be a let expression, if it isn't an empty list of
+--   bindings is returned.
+flattenLets ::
+  CoreSyn.CoreExpr -- ^ The expression to flatten.
+  -> ([Binding], CoreSyn.CoreExpr) -- ^ The bindings and resulting expression.
+flattenLets (CoreSyn.Let binds expr) = 
+  (bindings ++ bindings', expr')
+  where
+    -- Recursively flatten the contained expression
+    (bindings', expr') =flattenLets expr
+    -- Flatten our own bindings to remove the Rec / NonRec constructors
+    bindings = CoreSyn.flattenBinds [binds]
+flattenLets expr = ([], expr)
+
+-- | Create bunch of nested non-recursive let expressions from the given
+-- bindings. The first binding is bound at the highest level (and thus
+-- available in all other bindings).
+mkNonRecLets :: [Binding] -> CoreSyn.CoreExpr -> CoreSyn.CoreExpr
+mkNonRecLets bindings expr = MkCore.mkCoreLets binds expr
+  where
+    binds = map (uncurry CoreSyn.NonRec) bindings
+
 -- | A class of things that (optionally) have a core Type. The type is
 -- optional, since Type expressions don't have a type themselves.
 class TypedThing t where
 -- | A class of things that (optionally) have a core Type. The type is
 -- optional, since Type expressions don't have a type themselves.
 class TypedThing t where
@@ -312,3 +359,105 @@ instance TypedThing CoreSyn.CoreBndr where
 
 instance TypedThing Type.Type where
   getType = return . id
 
 instance TypedThing Type.Type where
   getType = return . id
+
+-- | Generate new uniques for all binders in the given expression.
+-- Does not support making type variables unique, though this could be
+-- supported if required (by passing a CoreSubst.Subst instead of VarEnv to
+-- genUniques' below).
+genUniques :: CoreSyn.CoreExpr -> TranslatorSession CoreSyn.CoreExpr
+genUniques = genUniques' VarEnv.emptyVarEnv
+
+-- | A helper function to generate uniques, that takes a VarEnv containing the
+--   substitutions already performed.
+genUniques' :: VarEnv.VarEnv CoreSyn.CoreBndr -> CoreSyn.CoreExpr -> TranslatorSession CoreSyn.CoreExpr
+genUniques' subst (CoreSyn.Var f) = do
+  -- Replace the binder with its new value, if applicable.
+  let f' = VarEnv.lookupWithDefaultVarEnv subst f f
+  return (CoreSyn.Var f')
+-- Leave literals untouched
+genUniques' subst (CoreSyn.Lit l) = return $ CoreSyn.Lit l
+genUniques' subst (CoreSyn.App f arg) = do
+  -- Only work on subexpressions
+  f' <- genUniques' subst f
+  arg' <- genUniques' subst arg
+  return (CoreSyn.App f' arg')
+-- Don't change type abstractions
+genUniques' subst expr@(CoreSyn.Lam bndr res) | CoreSyn.isTyVar bndr = return expr
+genUniques' subst (CoreSyn.Lam bndr res) = do
+  -- Generate a new unique for the bound variable
+  (subst', bndr') <- genUnique subst bndr
+  res' <- genUniques' subst' res
+  return (CoreSyn.Lam bndr' res')
+genUniques' subst (CoreSyn.Let (CoreSyn.NonRec bndr bound) res) = do
+  -- Make the binders unique
+  (subst', bndr') <- genUnique subst bndr
+  bound' <- genUniques' subst' bound
+  res' <- genUniques' subst' res
+  return $ CoreSyn.Let (CoreSyn.NonRec bndr' bound') res'
+genUniques' subst (CoreSyn.Let (CoreSyn.Rec binds) res) = do
+  -- Make each of the binders unique
+  (subst', bndrs') <- mapAccumLM genUnique subst (map fst binds)
+  bounds' <- mapM (genUniques' subst' . snd) binds
+  res' <- genUniques' subst' res
+  let binds' = zip bndrs' bounds'
+  return $ CoreSyn.Let (CoreSyn.Rec binds') res'
+genUniques' subst (CoreSyn.Case scrut bndr ty alts) = do
+  -- Process the scrutinee with the original substitution, since non of the
+  -- binders bound in the Case statement is in scope in the scrutinee.
+  scrut' <- genUniques' subst scrut
+  -- Generate a new binder for the scrutinee
+  (subst', bndr') <- genUnique subst bndr
+  -- Process each of the alts
+  alts' <- mapM (doalt subst') alts
+  return $ CoreSyn.Case scrut' bndr' ty alts'
+  where
+    doalt subst (con, bndrs, expr) = do
+      (subst', bndrs') <- mapAccumLM genUnique subst bndrs
+      expr' <- genUniques' subst' expr
+      -- Note that we don't return subst', since bndrs are only in scope in
+      -- expr.
+      return (con, bndrs', expr')
+genUniques' subst (CoreSyn.Cast expr coercion) = do
+  expr' <- genUniques' subst expr
+  -- Just process the casted expression
+  return $ CoreSyn.Cast expr' coercion
+genUniques' subst (CoreSyn.Note note expr) = do
+  expr' <- genUniques' subst expr
+  -- Just process the annotated expression
+  return $ CoreSyn.Note note expr'
+-- Leave types untouched
+genUniques' subst expr@(CoreSyn.Type _) = return expr
+
+-- Generate a new unique for the given binder, and extend the given
+-- substitution to reflect this.
+genUnique :: VarEnv.VarEnv CoreSyn.CoreBndr -> CoreSyn.CoreBndr -> TranslatorSession (VarEnv.VarEnv CoreSyn.CoreBndr, CoreSyn.CoreBndr)
+genUnique subst bndr = do
+  bndr' <- BinderTools.cloneVar bndr
+  -- Replace all occurences of the old binder with a reference to the new
+  -- binder.
+  let subst' = VarEnv.extendVarEnv subst bndr bndr'
+  return (subst', bndr')
+
+-- Create a "selector" case that selects the ith field from a datacon
+mkSelCase :: CoreSyn.CoreExpr -> Int -> TranslatorSession CoreSyn.CoreExpr
+mkSelCase scrut i = do
+  let scrut_ty = CoreUtils.exprType scrut
+  case Type.splitTyConApp_maybe scrut_ty of
+    -- The scrutinee should have a type constructor. We keep the type
+    -- arguments around so we can instantiate the field types below
+    Just (tycon, tyargs) -> case TyCon.tyConDataCons tycon of
+      -- The scrutinee type should have a single dataconstructor,
+      -- otherwise we can't construct a valid selector case.
+      [datacon] -> do
+        let field_tys = DataCon.dataConInstOrigArgTys datacon tyargs
+        -- Create a list of wild binders for the fields we don't want
+        let wildbndrs = map MkCore.mkWildBinder field_tys
+        -- Create a single binder for the field we want
+        sel_bndr <- mkInternalVar "sel" (field_tys!!i)
+        -- Create a wild binder for the scrutinee
+        let scrut_bndr = MkCore.mkWildBinder scrut_ty
+        -- Create the case expression
+        let binders = take i wildbndrs ++ [sel_bndr] ++ drop (i+1) wildbndrs
+        return $ CoreSyn.Case scrut scrut_bndr scrut_ty [(CoreSyn.DataAlt datacon, binders, CoreSyn.Var sel_bndr)]
+      dcs -> error $ "CoreTools.mkSelCase: Scrutinee type must have exactly one datacon. Extracting element " ++ (show i) ++ " from '" ++ pprString scrut ++ "' Datacons: " ++ (show dcs) ++ " Type: " ++ (pprString scrut_ty)
+    Nothing -> error $ "CoreTools.mkSelCase: Creating extractor case, but scrutinee has no tycon? Extracting element " ++ (show i) ++ " from '" ++ pprString scrut ++ "'" ++ " Type: " ++ (pprString scrut_ty)