don't forget to use offset/limit in get_count()
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index 49c183e5432f900bde835098917f9197df0cc251..3076e3f3364905bf2a272be6ad31934641cde093 100644 (file)
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
-# Copyright (C) 2009 Bolloré telecom
+# Copyright (C) 2009-2010 Bolloré telecom
 # See AUTHORS file for a full list of contributors.
 # 
 # This program is free software: you can redistribute it and/or modify
@@ -28,13 +28,15 @@ from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as
 
 import ldapdb
 
-def escape_ldap_filter(value):
-    value = str(value)
-    return value.replace('\\', '\\5c') \
-                .replace('*', '\\2a') \
-                .replace('(', '\\28') \
-                .replace(')', '\\29') \
-                .replace('\0', '\\00')
+from ldapdb.models.fields import CharField
+
+def get_lookup_operator(lookup_type):
+    if lookup_type == 'gte':
+        return '>='
+    elif lookup_type == 'lte':
+        return '<='
+    else:
+        return '='
 
 class Constraint(BaseConstraint):
     """
@@ -49,29 +51,92 @@ class Constraint(BaseConstraint):
         # Because of circular imports, we need to import this here.
         from django.db.models.base import ObjectDoesNotExist
 
-        if lookup_type == 'endswith':
-            params = ["*%s" % escape_ldap_filter(value)]
-        elif lookup_type == 'startswith':
-            params = ["%s*" % escape_ldap_filter(value)]
-        elif lookup_type == 'contains':
-            params = ["*%s*" % escape_ldap_filter(value)]
-        elif lookup_type == 'exact':
-            params = [escape_ldap_filter(value)]
-        elif lookup_type == 'in':
-            params = [escape_ldap_filter(v) for v in value]
-        else:
-            raise TypeError("Field has invalid lookup: %s" % lookup_type)
-
         try:
             if self.field:
+                params = self.field.get_db_prep_lookup(lookup_type, value)
                 db_type = self.field.db_type()
             else:
+                params = CharField().get_db_prep_lookup(lookup_type, value)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
+class Compiler(object):
+    def __init__(self, query, connection, using):
+        self.query = query
+        self.connection = connection
+        self.using = using
+
+    def results_iter(self):
+        if self.query.select_fields:
+            fields = self.query.select_fields
+        else:
+            fields = self.query.model._meta.fields
+
+        attrlist = [ x.db_column for x in fields if x.db_column ]
+
+        try:
+            vals = self.connection.search_s(
+                self.query.model.base_dn,
+                ldap.SCOPE_SUBTREE,
+                filterstr=self.query._ldap_filter(),
+                attrlist=attrlist,
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
+        # perform sorting
+        if self.query.extra_order_by:
+            ordering = self.query.extra_order_by
+        elif not self.query.default_ordering:
+            ordering = self.query.order_by
+        else:
+            ordering = self.query.order_by or self.query.model._meta.ordering
+        def cmpvals(x, y):
+            for fieldname in ordering:
+                if fieldname.startswith('-'):
+                    fieldname = fieldname[1:]
+                    negate = True
+                else:
+                    negate = False
+                field = self.query.model._meta.get_field(fieldname)
+                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
+                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
+                # perform case insensitive comparison
+                if hasattr(attr_x, 'lower'):
+                    attr_x = attr_x.lower()
+                if hasattr(attr_y, 'lower'):
+                    attr_y = attr_y.lower()
+                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
+                if val:
+                    return val
+            return 0
+        vals = sorted(vals, cmp=cmpvals)
+
+        # process results
+        pos = 0
+        for dn, attrs in vals:
+            # FIXME : This is not optimal, we retrieve more results than we need
+            # but there is probably no other options as we can't perform ordering
+            # server side.
+            if (self.query.low_mark and pos < self.query.low_mark) or \
+               (self.query.high_mark is not None and pos >= self.query.high_mark):
+                pos += 1
+                continue
+            row = []
+            for field in iter(fields):
+                if field.attname == 'dn':
+                    row.append(dn)
+                elif hasattr(field, 'from_ldap'):
+                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
+                else:
+                    row.append(None)
+            yield row
+            pos += 1
+
+
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -84,82 +149,111 @@ class WhereNode(BaseWhereNode):
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
-    def as_sql(self):
+    def as_sql(self, qn=None, connection=None):
         bits = []
         for item in self.children:
-            if isinstance(item, WhereNode):
-                bits.append(item.as_sql())
+            if hasattr(item, 'as_sql'):
+                sql, params = item.as_sql(qn=qn, connection=connection)
+                bits.append(sql)
                 continue
-            if len(item) == 4:
-                # django 1.1
-                (table, column, type), x, y, values = item
-            else:
-                # django 1.0
-                table, column, type, x, y, values = item
-            equal_bits = [ "(%s=%s)" % (column, value) for value in values ]
-            if len(equal_bits) == 1:
-                clause = equal_bits[0]
+
+            constraint, lookup_type, y, values = item
+            comp = get_lookup_operator(lookup_type)
+            if hasattr(constraint, "col"):
+                # django 1.2
+                column = constraint.col
+                if lookup_type == 'in':
+                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
+                    clause = '(|%s)' % ''.join(equal_bits)
+                else:
+                    clause = "(%s%s%s)" % (constraint.col, comp, values)
             else:
-                clause = '(|%s)' % ''.join(equal_bits)
+                # django 1.1
+                (table, column, db_type) = constraint
+                equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
+                if len(equal_bits) == 1:
+                    clause = equal_bits[0]
+                else:
+                    clause = '(|%s)' % ''.join(equal_bits)
+
             if self.negated:
                 bits.append('(!%s)' % clause)
             else:
                 bits.append(clause)
         if len(bits) == 1:
-            return bits[0]
+            sql_string = bits[0]
         elif self.connector == AND:
-            return '(&%s)' % ''.join(bits)
+            sql_string = '(&%s)' % ''.join(bits)
         elif self.connector == OR:
-            return '(|%s)' % ''.join(bits)
+            sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
+        return sql_string, []
 
 class Query(BaseQuery):
-    def results_iter(self):
-        # FIXME: use all object classes
-        filterstr = '(objectClass=%s)' % self.model.object_classes[0]
-        filterstr += self.where.as_sql()
-        filterstr = '(&%s)' % filterstr
-        attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
+    def __init__(self, *args, **kwargs):
+        super(Query, self).__init__(*args, **kwargs)
+        self.connection = ldapdb.connection
+
+    def _ldap_filter(self):
+        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
+        sql, params = self.where.as_sql()
+        filterstr += sql
+        return '(&%s)' % filterstr
 
+    def get_count(self, using=None):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
                 ldap.SCOPE_SUBTREE,
-                filterstr=filterstr,
-                attrlist=attrlist,
+                filterstr=self._ldap_filter(),
+                attrlist=[],
             )
-        except:
-            raise self.model.DoesNotExist
+        except ldap.NO_SUCH_OBJECT:
+            return 0
 
-        # perform sorting
-        if self.extra_order_by:
-            ordering = self.extra_order_by
-        elif not self.default_ordering:
-            ordering = self.order_by
-        else:
-            ordering = self.order_by or self.model._meta.ordering
-        def getkey(x):
-            keys = []
-            for k in ordering:
-                attr = self.model._meta.get_field(k).db_column
-                keys.append(x[1].get(attr, '').lower())
-            return keys
-        vals = sorted(vals, key=lambda x: getkey(x))
+        number = len(vals)
 
-        # process results
-        for dn, attrs in vals:
-            row = []
-            for field in iter(self.model._meta.fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                else:
-                    row.append(attrs.get(field.db_column, None))
-            yield row
+        # apply limit and offset
+        number = max(0, number - self.low_mark)
+        if self.high_mark is not None:
+            number = min(number, self.high_mark - self.low_mark)
+
+        return number
+
+    def get_compiler(self, using=None, connection=None):
+        return Compiler(self, ldapdb.connection, using)
+
+    def results_iter(self):
+        "For django 1.1 compatibility"
+        return self.get_compiler().results_iter()
 
 class QuerySet(BaseQuerySet):
-    def __init__(self, model=None, query=None):
+    def __init__(self, model=None, query=None, using=None):
         if not query:
-            query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model, query)
+            import inspect
+            spec = inspect.getargspec(BaseQuery.__init__)
+            if len(spec[0]) == 3:
+                # django 1.2
+                query = Query(model, WhereNode)
+            else:
+                # django 1.1
+                query = Query(model, None, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query)
+
+    def delete(self):
+        "Bulk deletion."
+        try:
+            vals = ldapdb.connection.search_s(
+                self.model.base_dn,
+                ldap.SCOPE_SUBTREE,
+                filterstr=self.query._ldap_filter(),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
+        # FIXME : there is probably a more efficient way to do this 
+        for dn, attrs in vals:
+            ldapdb.connection.delete_s(dn)