move compiler definition
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index 877ea9e40ee74c0b323361e39ca2ea893ba79f4f..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -1,21 +1,35 @@
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
-# Copyright (C) 2009 BollorĂ© telecom
+# Copyright (c) 2009-2010, BollorĂ© telecom
+# All rights reserved.
+# 
 # See AUTHORS file for a full list of contributors.
 # 
 # See AUTHORS file for a full list of contributors.
 # 
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+# 
+#     1. Redistributions of source code must retain the above copyright notice, 
+#        this list of conditions and the following disclaimer.
+#     
+#     2. Redistributions in binary form must reproduce the above copyright 
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
 # 
 # 
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
+#     3. Neither the name of BollorĂ© telecom nor the names of its contributors
+#        may be used to endorse or promote products derived from this software
+#        without specific prior written permission.
 # 
 # 
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 
 from copy import deepcopy
 #
 
 from copy import deepcopy
@@ -27,21 +41,27 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
+from ldapdb.backends.ldap import compiler
+from ldapdb.models.fields import CharField
 
 
-def escape_ldap_filter(value):
-    value = str(value)
-    return value.replace('\\', '\\5c') \
-                .replace('*', '\\2a') \
-                .replace('(', '\\28') \
-                .replace(')', '\\29') \
-                .replace('\0', '\\00')
+def get_lookup_operator(lookup_type):
+    if lookup_type == 'gte':
+        return '>='
+    elif lookup_type == 'lte':
+        return '<='
+    else:
+        return '='
 
 class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
 
 class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
+
+    NOTES: 
+    - we redefine this class, because when self.field is None calls
+    Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
     """
     """
-    def process(self, lookup_type, value):
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -49,21 +69,14 @@ class Constraint(BaseConstraint):
         # Because of circular imports, we need to import this here.
         from django.db.models.base import ObjectDoesNotExist
 
         # Because of circular imports, we need to import this here.
         from django.db.models.base import ObjectDoesNotExist
 
-        if lookup_type == 'endswith':
-            params = ["*%s" % escape_ldap_filter(value)]
-        elif lookup_type == 'startswith':
-            params = ["%s*" % escape_ldap_filter(value)]
-        elif lookup_type == 'exact':
-            params = [escape_ldap_filter(value)]
-        elif lookup_type == 'in':
-            params = [escape_ldap_filter(v) for v in value]
-        else:
-            raise TypeError("Field has invalid lookup: %s" % lookup_type)
-
         try:
             if self.field:
         try:
             if self.field:
+                params = self.field.get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = self.field.db_type()
             else:
                 db_type = self.field.db_type()
             else:
+                params = CharField().get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
@@ -82,82 +95,97 @@ class WhereNode(BaseWhereNode):
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
-    def as_sql(self):
+    def as_sql(self, qn=None, connection=None):
         bits = []
         for item in self.children:
         bits = []
         for item in self.children:
-            if isinstance(item, WhereNode):
-                bits.append(item.as_sql())
+            if hasattr(item, 'as_sql'):
+                sql, params = item.as_sql(qn=qn, connection=connection)
+                bits.append(sql)
                 continue
                 continue
-            if len(item) == 4:
-                # django 1.1
-                (table, column, type), x, y, values = item
-            else:
-                # django 1.0
-                table, column, type, x, y, values = item
-            equal_bits = [ "(%s=%s)" % (column, value) for value in values ]
-            if len(equal_bits) == 1:
-                clause = equal_bits[0]
-            else:
+
+            constraint, lookup_type, y, values = item
+            comp = get_lookup_operator(lookup_type)
+            if lookup_type == 'in':
+                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
                 clause = '(|%s)' % ''.join(equal_bits)
                 clause = '(|%s)' % ''.join(equal_bits)
-            if self.negated:
-                bits.append('(!%s)' % clause)
             else:
             else:
-                bits.append(clause)
+                clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+            bits.append(clause)
+
+        if not len(bits):
+            return '', []
+
         if len(bits) == 1:
         if len(bits) == 1:
-            return bits[0]
+            sql_string = bits[0]
         elif self.connector == AND:
         elif self.connector == AND:
-            return '(&%s)' % ''.join(bits)
+            sql_string = '(&%s)' % ''.join(bits)
         elif self.connector == OR:
         elif self.connector == OR:
-            return '(|%s)' % ''.join(bits)
+            sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
 
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
 
+        if self.negated:
+            sql_string = ('(!%s)' % sql_string)
+
+        return sql_string, []
+
 class Query(BaseQuery):
 class Query(BaseQuery):
-    def results_iter(self):
-        # FIXME: use all object classes
-        filterstr = '(objectClass=%s)' % self.model.object_classes[0]
-        filterstr += self.where.as_sql()
-        filterstr = '(&%s)' % filterstr
-        attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
+    def __init__(self, *args, **kwargs):
+        super(Query, self).__init__(*args, **kwargs)
+        self.connection = ldapdb.connection
+
+    def _ldap_filter(self):
+        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
+        sql, params = self.where.as_sql()
+        filterstr += sql
+        return '(&%s)' % filterstr
 
 
+    def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
-                ldap.SCOPE_SUBTREE,
-                filterstr=filterstr,
-                attrlist=attrlist,
+                self.model.search_scope,
+                filterstr=self._ldap_filter(),
+                attrlist=[],
             )
             )
-        except:
-            raise self.model.DoesNotExist
-
-        # perform sorting
-        if self.extra_order_by:
-            ordering = self.extra_order_by
-        elif not self.default_ordering:
-            ordering = self.order_by
-        else:
-            ordering = self.order_by or self.model._meta.ordering
-        def getkey(x):
-            keys = []
-            for k in ordering:
-                attr = self.model._meta.get_field(k).db_column
-                keys.append(x[1].get(attr, '').lower())
-            return keys
-        vals = sorted(vals, key=lambda x: getkey(x))
-
-        # process results
-        for dn, attrs in vals:
-            row = []
-            for field in iter(self.model._meta.fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                else:
-                    row.append(attrs.get(field.db_column, None))
-            yield row
+        except ldap.NO_SUCH_OBJECT:
+            return 0
+
+        number = len(vals)
+
+        # apply limit and offset
+        number = max(0, number - self.low_mark)
+        if self.high_mark is not None:
+            number = min(number, self.high_mark - self.low_mark)
+
+        return number
+
+    def get_compiler(self, using=None, connection=None):
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
+
+    def has_results(self, using):
+        return self.get_count(using) != 0
 
 class QuerySet(BaseQuerySet):
 
 class QuerySet(BaseQuerySet):
-    def __init__(self, model=None, query=None):
+    def __init__(self, model=None, query=None, using=None):
         if not query:
         if not query:
-            query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model, query)
+            query = Query(model, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query, using=using)
+
+    def delete(self):
+        "Bulk deletion."
+        try:
+            vals = ldapdb.connection.search_s(
+                self.model.base_dn,
+                self.model.search_scope,
+                filterstr=self.query._ldap_filter(),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
+        # FIXME : there is probably a more efficient way to do this 
+        for dn, attrs in vals:
+            ldapdb.connection.delete_s(dn)