accept "using" keyword for get_count
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index eb35061ca41e5eb889c34d0e0b90c4425ee9f737..f433956d52071c5bdac9f71a842dcf565449193f 100644 (file)
@@ -63,6 +63,72 @@ class Constraint(BaseConstraint):
 
         return (self.alias, self.col, db_type), params
 
+class Compiler(object):
+    def __init__(self, query, connection, using):
+        self.query = query
+        self.connection = connection
+        self.using = using
+
+    def results_iter(self):
+        query = self.query
+
+        filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
+        sql, params = query.where.as_sql()
+        filterstr += sql
+        filterstr = '(&%s)' % filterstr
+        attrlist = [ x.db_column for x in query.model._meta.local_fields if x.db_column ]
+
+        try:
+            vals = self.connection.search_s(
+                query.model.base_dn,
+                ldap.SCOPE_SUBTREE,
+                filterstr=filterstr,
+                attrlist=attrlist,
+            )
+        except:
+            raise query.model.DoesNotExist
+
+        # perform sorting
+        if query.extra_order_by:
+            ordering = query.extra_order_by
+        elif not query.default_ordering:
+            ordering = query.order_by
+        else:
+            ordering = query.order_by or query.model._meta.ordering
+        def cmpvals(x, y):
+            for fieldname in ordering:
+                if fieldname.startswith('-'):
+                    fieldname = fieldname[1:]
+                    negate = True
+                else:
+                    negate = False
+                field = query.model._meta.get_field(fieldname)
+                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
+                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
+                # perform case insensitive comparison
+                if hasattr(attr_x, 'lower'):
+                    attr_x = attr_x.lower()
+                if hasattr(attr_y, 'lower'):
+                    attr_y = attr_y.lower()
+                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
+                if val:
+                    return val
+            return 0
+        vals = sorted(vals, cmp=cmpvals)
+
+        # process results
+        for dn, attrs in vals:
+            row = []
+            for field in iter(query.model._meta.fields):
+                if field.attname == 'dn':
+                    row.append(dn)
+                elif hasattr(field, 'from_ldap'):
+                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
+                else:
+                    row.append(None)
+            yield row
+
+
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -75,11 +141,11 @@ class WhereNode(BaseWhereNode):
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
-    def as_sql(self, qn=None):
+    def as_sql(self, qn=None, connection=None):
         bits = []
         for item in self.children:
-            if isinstance(item, WhereNode):
-                sql, params = item.as_sql()
+            if hasattr(item, 'as_sql'):
+                sql, params = item.as_sql(qn=qn, connection=connection)
                 bits.append(sql)
                 continue
 
@@ -121,15 +187,14 @@ class Query(BaseQuery):
         super(Query, self).__init__(*args, **kwargs)
         self.connection = ldapdb.connection
 
-    def get_count(self):
-        # FIXME: use all object classes
-        filterstr = '(objectClass=%s)' % self.model.object_classes[0]
+    def get_count(self, using=None):
+        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
         sql, params = self.where.as_sql()
         filterstr += sql
         filterstr = '(&%s)' % filterstr
 
         try:
-            vals = ldapdb.connection.search_s(
+            vals = self.connection.search_s(
                 self.model.base_dn,
                 ldap.SCOPE_SUBTREE,
                 filterstr=filterstr,
@@ -140,56 +205,12 @@ class Query(BaseQuery):
 
         return len(vals)
 
-    def results_iter(self):
-        # FIXME: use all object classes
-        filterstr = '(objectClass=%s)' % self.model.object_classes[0]
-        sql, params = self.where.as_sql()
-        filterstr += sql
-        filterstr = '(&%s)' % filterstr
-        attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
+    def get_compiler(self, using=None, connection=None):
+        return Compiler(self, ldapdb.connection, using)
 
-        try:
-            vals = ldapdb.connection.search_s(
-                self.model.base_dn,
-                ldap.SCOPE_SUBTREE,
-                filterstr=filterstr,
-                attrlist=attrlist,
-            )
-        except:
-            raise self.model.DoesNotExist
-
-        # perform sorting
-        if self.extra_order_by:
-            ordering = self.extra_order_by
-        elif not self.default_ordering:
-            ordering = self.order_by
-        else:
-            ordering = self.order_by or self.model._meta.ordering
-        def cmpvals(x, y):
-            for field in ordering:
-                if field.startswith('-'):
-                    field = field[1:]
-                    negate = True
-                else:
-                    negate = False
-                attr = self.model._meta.get_field(field).db_column
-                attr_x = x[1].get(attr, '').lower()
-                attr_y = y[1].get(attr, '').lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
-
-        # process results
-        for dn, attrs in vals:
-            row = []
-            for field in iter(self.model._meta.fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                else:
-                    row.append(attrs.get(field.db_column, None))
-            yield row
+    def results_iter(self):
+        "For django 1.1 compatibility"
+        return self.get_compiler().results_iter()
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):