move compiler definition
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index ab5f19ec727b574cc845d44995d1b4371611ead7..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -41,7 +41,7 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
-
+from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
@@ -56,8 +56,12 @@ class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
+
+    NOTES: 
+    - we redefine this class, because when self.field is None calls
+    Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
     """
     """
-    def process(self, lookup_type, value):
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -67,90 +71,18 @@ class Constraint(BaseConstraint):
 
         try:
             if self.field:
 
         try:
             if self.field:
-                params = self.field.get_db_prep_lookup(lookup_type, value)
+                params = self.field.get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = self.field.db_type()
             else:
                 db_type = self.field.db_type()
             else:
-                params = CharField().get_db_prep_lookup(lookup_type, value)
+                params = CharField().get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
-class Compiler(object):
-    def __init__(self, query, connection, using):
-        self.query = query
-        self.connection = connection
-        self.using = using
-
-    def results_iter(self):
-        if self.query.select_fields:
-            fields = self.query.select_fields
-        else:
-            fields = self.query.model._meta.fields
-
-        attrlist = [ x.db_column for x in fields if x.db_column ]
-
-        try:
-            vals = self.connection.search_s(
-                self.query.model.base_dn,
-                self.query.model.search_scope,
-                filterstr=self.query._ldap_filter(),
-                attrlist=attrlist,
-            )
-        except ldap.NO_SUCH_OBJECT:
-            return
-
-        # perform sorting
-        if self.query.extra_order_by:
-            ordering = self.query.extra_order_by
-        elif not self.query.default_ordering:
-            ordering = self.query.order_by
-        else:
-            ordering = self.query.order_by or self.query.model._meta.ordering
-        def cmpvals(x, y):
-            for fieldname in ordering:
-                if fieldname.startswith('-'):
-                    fieldname = fieldname[1:]
-                    negate = True
-                else:
-                    negate = False
-                field = self.query.model._meta.get_field(fieldname)
-                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
-                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
-                # perform case insensitive comparison
-                if hasattr(attr_x, 'lower'):
-                    attr_x = attr_x.lower()
-                if hasattr(attr_y, 'lower'):
-                    attr_y = attr_y.lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
-
-        # process results
-        pos = 0
-        for dn, attrs in vals:
-            # FIXME : This is not optimal, we retrieve more results than we need
-            # but there is probably no other options as we can't perform ordering
-            # server side.
-            if (self.query.low_mark and pos < self.query.low_mark) or \
-               (self.query.high_mark is not None and pos >= self.query.high_mark):
-                pos += 1
-                continue
-            row = []
-            for field in iter(fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                elif hasattr(field, 'from_ldap'):
-                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
-                else:
-                    row.append(None)
-            yield row
-            pos += 1
-
-
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -173,27 +105,17 @@ class WhereNode(BaseWhereNode):
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
-            if hasattr(constraint, "col"):
-                # django 1.2
-                column = constraint.col
-                if lookup_type == 'in':
-                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                    clause = '(|%s)' % ''.join(equal_bits)
-                else:
-                    clause = "(%s%s%s)" % (constraint.col, comp, values)
+            if lookup_type == 'in':
+                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+                clause = '(|%s)' % ''.join(equal_bits)
             else:
             else:
-                # django 1.1
-                (table, column, db_type) = constraint
-                equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                if len(equal_bits) == 1:
-                    clause = equal_bits[0]
-                else:
-                    clause = '(|%s)' % ''.join(equal_bits)
+                clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+            bits.append(clause)
+
+        if not len(bits):
+            return '', []
 
 
-            if self.negated:
-                bits.append('(!%s)' % clause)
-            else:
-                bits.append(clause)
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
@@ -202,6 +124,10 @@ class WhereNode(BaseWhereNode):
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+        if self.negated:
+            sql_string = ('(!%s)' % sql_string)
+
         return sql_string, []
 
 class Query(BaseQuery):
         return sql_string, []
 
 class Query(BaseQuery):
@@ -215,7 +141,7 @@ class Query(BaseQuery):
         filterstr += sql
         return '(&%s)' % filterstr
 
         filterstr += sql
         return '(&%s)' % filterstr
 
-    def get_count(self, using=None):
+    def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
@@ -236,27 +162,16 @@ class Query(BaseQuery):
         return number
 
     def get_compiler(self, using=None, connection=None):
         return number
 
     def get_compiler(self, using=None, connection=None):
-        return Compiler(self, ldapdb.connection, using)
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
 
     def has_results(self, using):
 
     def has_results(self, using):
-        return self.get_count() != 0
-
-    def results_iter(self):
-        "For django 1.1 compatibility"
-        return self.get_compiler().results_iter()
+        return self.get_count(using) != 0
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
-            import inspect
-            spec = inspect.getargspec(BaseQuery.__init__)
-            if len(spec[0]) == 3:
-                # django 1.2
-                query = Query(model, WhereNode)
-            else:
-                # django 1.1
-                query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model=model, query=query)
+            query = Query(model, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query, using=using)
 
     def delete(self):
         "Bulk deletion."
 
     def delete(self):
         "Bulk deletion."