move compiler definition
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
index f5c402040f8826ce05a38d5947c30998bbc9591e..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -1,21 +1,35 @@
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
 # -*- coding: utf-8 -*-
 # 
 # django-ldapdb
-# Copyright (C) 2009-2010 BollorĂ© telecom
+# Copyright (c) 2009-2010, BollorĂ© telecom
+# All rights reserved.
+# 
 # See AUTHORS file for a full list of contributors.
 # 
 # See AUTHORS file for a full list of contributors.
 # 
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+# 
+#     1. Redistributions of source code must retain the above copyright notice, 
+#        this list of conditions and the following disclaimer.
+#     
+#     2. Redistributions in binary form must reproduce the above copyright 
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
 # 
 # 
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
+#     3. Neither the name of BollorĂ© telecom nor the names of its contributors
+#        may be used to endorse or promote products derived from this software
+#        without specific prior written permission.
 # 
 # 
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 
 from copy import deepcopy
 #
 
 from copy import deepcopy
@@ -27,7 +41,7 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
-
+from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
@@ -42,8 +56,12 @@ class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
     """
     An object that can be passed to WhereNode.add() and knows how to
     pre-process itself prior to including in the WhereNode.
+
+    NOTES: 
+    - we redefine this class, because when self.field is None calls
+    Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
     """
     """
-    def process(self, lookup_type, value):
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -53,10 +71,12 @@ class Constraint(BaseConstraint):
 
         try:
             if self.field:
 
         try:
             if self.field:
-                params = self.field.get_db_prep_lookup(lookup_type, value)
+                params = self.field.get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = self.field.db_type()
             else:
                 db_type = self.field.db_type()
             else:
-                params = CharField().get_db_prep_lookup(lookup_type, value)
+                params = CharField().get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
@@ -75,37 +95,27 @@ class WhereNode(BaseWhereNode):
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
-    def as_sql(self, qn=None):
+    def as_sql(self, qn=None, connection=None):
         bits = []
         for item in self.children:
         bits = []
         for item in self.children:
-            if isinstance(item, WhereNode):
-                sql, params = item.as_sql()
+            if hasattr(item, 'as_sql'):
+                sql, params = item.as_sql(qn=qn, connection=connection)
                 bits.append(sql)
                 continue
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
                 bits.append(sql)
                 continue
 
             constraint, lookup_type, y, values = item
             comp = get_lookup_operator(lookup_type)
-            if hasattr(constraint, "col"):
-                # django 1.2
-                column = constraint.col
-                if lookup_type == 'in':
-                    equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                    clause = '(|%s)' % ''.join(equal_bits)
-                else:
-                    clause = "(%s%s%s)" % (constraint.col, comp, values)
+            if lookup_type == 'in':
+                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+                clause = '(|%s)' % ''.join(equal_bits)
             else:
             else:
-                # django 1.1
-                (table, column, db_type) = constraint
-                equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
-                if len(equal_bits) == 1:
-                    clause = equal_bits[0]
-                else:
-                    clause = '(|%s)' % ''.join(equal_bits)
-
-            if self.negated:
-                bits.append('(!%s)' % clause)
-            else:
-                bits.append(clause)
+                clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+            bits.append(clause)
+
+        if not len(bits):
+            return '', []
+
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
         if len(bits) == 1:
             sql_string = bits[0]
         elif self.connector == AND:
@@ -114,6 +124,10 @@ class WhereNode(BaseWhereNode):
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
             sql_string = '(|%s)' % ''.join(bits)
         else:
             raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+        if self.negated:
+            sql_string = ('(!%s)' % sql_string)
+
         return sql_string, []
 
 class Query(BaseQuery):
         return sql_string, []
 
 class Query(BaseQuery):
@@ -121,84 +135,57 @@ class Query(BaseQuery):
         super(Query, self).__init__(*args, **kwargs)
         self.connection = ldapdb.connection
 
         super(Query, self).__init__(*args, **kwargs)
         self.connection = ldapdb.connection
 
-    def get_count(self):
+    def _ldap_filter(self):
         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
         sql, params = self.where.as_sql()
         filterstr += sql
         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
         sql, params = self.where.as_sql()
         filterstr += sql
-        filterstr = '(&%s)' % filterstr
+        return '(&%s)' % filterstr
 
 
+    def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
-                ldap.SCOPE_SUBTREE,
-                filterstr=filterstr,
+                self.model.search_scope,
+                filterstr=self._ldap_filter(),
                 attrlist=[],
             )
                 attrlist=[],
             )
-        except:
-            raise self.model.DoesNotExist
+        except ldap.NO_SUCH_OBJECT:
+            return 0
 
 
-        return len(vals)
+        number = len(vals)
 
 
-    def results_iter(self):
-        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
-        sql, params = self.where.as_sql()
-        filterstr += sql
-        filterstr = '(&%s)' % filterstr
-        attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
+        # apply limit and offset
+        number = max(0, number - self.low_mark)
+        if self.high_mark is not None:
+            number = min(number, self.high_mark - self.low_mark)
+
+        return number
 
 
+    def get_compiler(self, using=None, connection=None):
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
+
+    def has_results(self, using):
+        return self.get_count(using) != 0
+
+class QuerySet(BaseQuerySet):
+    def __init__(self, model=None, query=None, using=None):
+        if not query:
+            query = Query(model, WhereNode)
+        super(QuerySet, self).__init__(model=model, query=query, using=using)
+
+    def delete(self):
+        "Bulk deletion."
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
-                ldap.SCOPE_SUBTREE,
-                filterstr=filterstr,
-                attrlist=attrlist,
+                self.model.search_scope,
+                filterstr=self.query._ldap_filter(),
+                attrlist=[],
             )
             )
-        except:
-            raise self.model.DoesNotExist
-
-        # perform sorting
-        if self.extra_order_by:
-            ordering = self.extra_order_by
-        elif not self.default_ordering:
-            ordering = self.order_by
-        else:
-            ordering = self.order_by or self.model._meta.ordering
-        def cmpvals(x, y):
-            for field in ordering:
-                if field.startswith('-'):
-                    field = field[1:]
-                    negate = True
-                else:
-                    negate = False
-                attr = self.model._meta.get_field(field).db_column
-                attr_x = x[1].get(attr, '').lower()
-                attr_y = y[1].get(attr, '').lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
+        except ldap.NO_SUCH_OBJECT:
+            return
 
 
-        # process results
+        # FIXME : there is probably a more efficient way to do this 
         for dn, attrs in vals:
         for dn, attrs in vals:
-            row = []
-            for field in iter(self.model._meta.fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                else:
-                    row.append(attrs.get(field.db_column, None))
-            yield row
-
-class QuerySet(BaseQuerySet):
-    def __init__(self, model=None, query=None, using=None):
-        if not query:
-            import inspect
-            spec = inspect.getargspec(BaseQuery.__init__)
-            if len(spec[0]) == 3:
-                # django 1.2
-                query = Query(model, WhereNode)
-            else:
-                # django 1.1
-                query = Query(model, None, WhereNode)
-        super(QuerySet, self).__init__(model=model, query=query)
+            ldapdb.connection.delete_s(dn)