move compiler definition
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index 91a3f62e2058bf8587342559c04ffefee3c58487..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -41,7 +41,7 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
-
+from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
@@ -60,9 +60,8 @@ class Constraint(BaseConstraint):
     NOTES: 
     - we redefine this class, because when self.field is None calls
     Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
     NOTES: 
     - we redefine this class, because when self.field is None calls
     Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
-    - the connection argument defaults to None for django 1.1 compatibility
     """
     """
-    def process(self, lookup_type, value, connection=None):
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -84,80 +83,6 @@ class Constraint(BaseConstraint):
 
         return (self.alias, self.col, db_type), params
 
 
         return (self.alias, self.col, db_type), params
 
-class Compiler(object):
-    def __init__(self, query, connection, using):
-        self.query = query
-        self.connection = connection
-        self.using = using
-
-    def results_iter(self):
-        if self.query.select_fields:
-            fields = self.query.select_fields
-        else:
-            fields = self.query.model._meta.fields
-
-        attrlist = [ x.db_column for x in fields if x.db_column ]
-
-        try:
-            vals = self.connection.search_s(
-                self.query.model.base_dn,
-                self.query.model.search_scope,
-                filterstr=self.query._ldap_filter(),
-                attrlist=attrlist,
-            )
-        except ldap.NO_SUCH_OBJECT:
-            return
-
-        # perform sorting
-        if self.query.extra_order_by:
-            ordering = self.query.extra_order_by
-        elif not self.query.default_ordering:
-            ordering = self.query.order_by
-        else:
-            ordering = self.query.order_by or self.query.model._meta.ordering
-        def cmpvals(x, y):
-            for fieldname in ordering:
-                if fieldname.startswith('-'):
-                    fieldname = fieldname[1:]
-                    negate = True
-                else:
-                    negate = False
-                field = self.query.model._meta.get_field(fieldname)
-                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
-                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
-                # perform case insensitive comparison
-                if hasattr(attr_x, 'lower'):
-                    attr_x = attr_x.lower()
-                if hasattr(attr_y, 'lower'):
-                    attr_y = attr_y.lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
-
-        # process results
-        pos = 0
-        for dn, attrs in vals:
-            # FIXME : This is not optimal, we retrieve more results than we need
-            # but there is probably no other options as we can't perform ordering
-            # server side.
-            if (self.query.low_mark and pos < self.query.low_mark) or \
-               (self.query.high_mark is not None and pos >= self.query.high_mark):
-                pos += 1
-                continue
-            row = []
-            for field in iter(fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                elif hasattr(field, 'from_ldap'):
-                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
-                else:
-                    row.append(None)
-            yield row
-            pos += 1
-
-
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -180,22 +105,11 @@ class WhereNode(BaseWhereNode):
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
-            if hasattr(constraint, "col"):
-                # django 1.2
-                column = constraint.col
-                if lookup_type == 'in':
-                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                    clause = '(|%s)' % ''.join(equal_bits)
-                else:
-                    clause = "(%s%s%s)" % (constraint.col, comp, values)
+            if lookup_type == 'in':
+                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+                clause = '(|%s)' % ''.join(equal_bits)
             else:
             else:
-                # django 1.1
-                (table, column, db_type) = constraint
-                equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                if len(equal_bits) == 1:
-                    clause = equal_bits[0]
-                else:
-                    clause = '(|%s)' % ''.join(equal_bits)
+                clause = "(%s%s%s)" % (constraint.col, comp, values)
 
             bits.append(clause)
 
 
             bits.append(clause)
 
@@ -227,7 +141,7 @@ class Query(BaseQuery):
         filterstr += sql
         return '(&%s)' % filterstr
 
         filterstr += sql
         return '(&%s)' % filterstr
 
-    def get_count(self, using=None):
+    def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
@@ -248,27 +162,16 @@ class Query(BaseQuery):
         return number
 
     def get_compiler(self, using=None, connection=None):
         return number
 
     def get_compiler(self, using=None, connection=None):
-        return Compiler(self, ldapdb.connection, using)
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
 
     def has_results(self, using):
 
     def has_results(self, using):
-        return self.get_count() != 0
-
-    def results_iter(self):
-        "For django 1.1 compatibility"
-        return self.get_compiler().results_iter()
+        return self.get_count(using) != 0
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
-            import inspect
-            spec = inspect.getargspec(BaseQuery.__init__)
-            if len(spec[0]) == 3:
-                # django 1.2
-                query = Query(model, WhereNode)
-            else:
-                # django 1.1
-                query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model=model, query=query)
+            query = Query(model, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query, using=using)
 
     def delete(self):
         "Bulk deletion."
 
     def delete(self):
         "Bulk deletion."