move compiler definition
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index 4f0273cc2ed6478f7fcd78897251cd1732bb932d..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -1,21 +1,35 @@
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
-# Copyright (C) 2009-2010 BollorĂ© telecom
+# Copyright (c) 2009-2010, BollorĂ© telecom
+# All rights reserved.
+# 
 # See AUTHORS file for a full list of contributors.
 # 
 # See AUTHORS file for a full list of contributors.
 # 
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+# 
+#     1. Redistributions of source code must retain the above copyright notice, 
+#        this list of conditions and the following disclaimer.
+#     
+#     2. Redistributions in binary form must reproduce the above copyright 
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
 # 
 # 
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
+#     3. Neither the name of BollorĂ© telecom nor the names of its contributors
+#        may be used to endorse or promote products derived from this software
+#        without specific prior written permission.
 # 
 # 
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 
 from copy import deepcopy
 #
 
 from copy import deepcopy
@@ -27,7 +41,7 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
-
+from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
@@ -42,8 +56,12 @@ class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
+
+    NOTES: 
+    - we redefine this class, because when self.field is None calls
+    Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
     """
     """
-    def process(self, lookup_type, value):
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -53,74 +71,18 @@ class Constraint(BaseConstraint):
 
         try:
             if self.field:
 
         try:
             if self.field:
-                params = self.field.get_db_prep_lookup(lookup_type, value)
+                params = self.field.get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = self.field.db_type()
             else:
                 db_type = self.field.db_type()
             else:
-                params = CharField().get_db_prep_lookup(lookup_type, value)
+                params = CharField().get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
-class Compiler(object):
-    def __init__(self, query, connection, using):
-        self.query = query
-        self.connection = connection
-        self.using = using
-
-    def results_iter(self):
-        query = self.query
-        attrlist = [ x.db_column for x in query.model._meta.local_fields if x.db_column ]
-
-        vals = self.connection.search_s(
-            query.model.base_dn,
-            ldap.SCOPE_SUBTREE,
-            filterstr=query._ldap_filter(),
-            attrlist=attrlist,
-        )
-
-        # perform sorting
-        if query.extra_order_by:
-            ordering = query.extra_order_by
-        elif not query.default_ordering:
-            ordering = query.order_by
-        else:
-            ordering = query.order_by or query.model._meta.ordering
-        def cmpvals(x, y):
-            for fieldname in ordering:
-                if fieldname.startswith('-'):
-                    fieldname = fieldname[1:]
-                    negate = True
-                else:
-                    negate = False
-                field = query.model._meta.get_field(fieldname)
-                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
-                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
-                # perform case insensitive comparison
-                if hasattr(attr_x, 'lower'):
-                    attr_x = attr_x.lower()
-                if hasattr(attr_y, 'lower'):
-                    attr_y = attr_y.lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
-
-        # process results
-        for dn, attrs in vals:
-            row = []
-            for field in iter(query.model._meta.fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                elif hasattr(field, 'from_ldap'):
-                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
-                else:
-                    row.append(None)
-            yield row
-
-
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -143,27 +105,17 @@ class WhereNode(BaseWhereNode):
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
-            if hasattr(constraint, "col"):
-                # django 1.2
-                column = constraint.col
-                if lookup_type == 'in':
-                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                    clause = '(|%s)' % ''.join(equal_bits)
-                else:
-                    clause = "(%s%s%s)" % (constraint.col, comp, values)
+            if lookup_type == 'in':
+                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+                clause = '(|%s)' % ''.join(equal_bits)
             else:
             else:
-                # django 1.1
-                (table, column, db_type) = constraint
-                equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                if len(equal_bits) == 1:
-                    clause = equal_bits[0]
-                else:
-                    clause = '(|%s)' % ''.join(equal_bits)
-
-            if self.negated:
-                bits.append('(!%s)' % clause)
-            else:
-                bits.append(clause)
+                clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+            bits.append(clause)
+
+        if not len(bits):
+            return '', []
+
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
@@ -172,6 +124,10 @@ class WhereNode(BaseWhereNode):
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+        if self.negated:
+            sql_string = ('(!%s)' % sql_string)
+
         return sql_string, []
 
 class Query(BaseQuery):
         return sql_string, []
 
 class Query(BaseQuery):
@@ -185,43 +141,50 @@ class Query(BaseQuery):
         filterstr += sql
         return '(&%s)' % filterstr
 
         filterstr += sql
         return '(&%s)' % filterstr
 
-    def get_count(self, using=None):
-        vals = ldapdb.connection.search_s(
-            self.model.base_dn,
-            ldap.SCOPE_SUBTREE,
-            filterstr=self._ldap_filter(),
-            attrlist=[],
-        )
-        return len(vals)
+    def get_count(self, using):
+        try:
+            vals = ldapdb.connection.search_s(
+                self.model.base_dn,
+                self.model.search_scope,
+                filterstr=self._ldap_filter(),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return 0
+
+        number = len(vals)
+
+        # apply limit and offset
+        number = max(0, number - self.low_mark)
+        if self.high_mark is not None:
+            number = min(number, self.high_mark - self.low_mark)
+
+        return number
 
     def get_compiler(self, using=None, connection=None):
 
     def get_compiler(self, using=None, connection=None):
-        return Compiler(self, ldapdb.connection, using)
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
 
 
-    def results_iter(self):
-        "For django 1.1 compatibility"
-        return self.get_compiler().results_iter()
+    def has_results(self, using):
+        return self.get_count(using) != 0
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
 
 class QuerySet(BaseQuerySet):
     def __init__(self, model=None, query=None, using=None):
         if not query:
-            import inspect
-            spec = inspect.getargspec(BaseQuery.__init__)
-            if len(spec[0]) == 3:
-                # django 1.2
-                query = Query(model, WhereNode)
-            else:
-                # django 1.1
-                query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model=model, query=query)
+            query = Query(model, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query, using=using)
 
     def delete(self):
         "Bulk deletion."
 
     def delete(self):
         "Bulk deletion."
-        vals = ldapdb.connection.search_s(
-            self.model.base_dn,
-            ldap.SCOPE_SUBTREE,
-            filterstr=self.query._ldap_filter(),
-            attrlist=[],
-        )
+        try:
+            vals = ldapdb.connection.search_s(
+                self.model.base_dn,
+                self.model.search_scope,
+                filterstr=self.query._ldap_filter(),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
         # FIXME : there is probably a more efficient way to do this 
         for dn, attrs in vals:
             ldapdb.connection.delete_s(dn)
         # FIXME : there is probably a more efficient way to do this 
         for dn, attrs in vals:
             ldapdb.connection.delete_s(dn)