allow order_by using a '-' prefix
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index a8c614f7ccfe6a2ee487ac35b38bb474951efa0c..b35c12311d29bead6b88b0ebba43f74b4ee681e1 100644 (file)
@@ -81,20 +81,26 @@ class WhereNode(BaseWhereNode):
             if isinstance(item, WhereNode):
                 bits.append(item.as_sql())
                 continue
-            constraint, x, y, values = item
+
+            constraint, lookup_type, y, values = item
+            comp = get_lookup_operator(lookup_type)
             if hasattr(constraint, "col"):
                 # django 1.2
-                comp = get_lookup_operator(constraint.lookup_type)
-                clause = "(%s%s%s)" % (constraint.col, comp, values)
+                column = constraint.col
+                if lookup_type == 'in':
+                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
+                    clause = '(|%s)' % ''.join(equal_bits)
+                else:
+                    clause = "(%s%s%s)" % (constraint.col, comp, values)
             else:
                 # django 1.1
                 (table, column, db_type) = constraint
-                comp = get_lookup_operator(x)
                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
                 if len(equal_bits) == 1:
                     clause = equal_bits[0]
                 else:
                     clause = '(|%s)' % ''.join(equal_bits)
+
             if self.negated:
                 bits.append('(!%s)' % clause)
             else:
@@ -133,13 +139,21 @@ class Query(BaseQuery):
             ordering = self.order_by
         else:
             ordering = self.order_by or self.model._meta.ordering
-        def getkey(x):
-            keys = []
-            for k in ordering:
-                attr = self.model._meta.get_field(k).db_column
-                keys.append(x[1].get(attr, '').lower())
-            return keys
-        vals = sorted(vals, key=lambda x: getkey(x))
+        def cmpvals(x, y):
+            for field in ordering:
+                if field.startswith('-'):
+                    field = field[1:]
+                    negate = True
+                else:
+                    negate = False
+                attr = self.model._meta.get_field(field).db_column
+                attr_x = x[1].get(attr, '').lower()
+                attr_y = y[1].get(attr, '').lower()
+                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
+                if val:
+                    return val
+            return 0
+        vals = sorted(vals, cmp=cmpvals)
 
         # process results
         for dn, attrs in vals:
@@ -152,8 +166,15 @@ class Query(BaseQuery):
             yield row
 
 class QuerySet(BaseQuerySet):
-    def __init__(self, model=None, query=None):
+    def __init__(self, model=None, query=None, using=None):
         if not query:
-            query = Query(model, None, WhereNode)
+            import inspect
+            spec = inspect.getargspec(Query.__init__)
+            if len(spec[0]) == 3:
+                # django 1.2
+                query = Query(model, WhereNode)
+            else:
+                # django 1.1
+                query = Query(model, None, WhereNode)
         super(QuerySet, self).__init__(model, query)