move LDAP compilation to the backend
authorjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Mon, 11 Apr 2011 10:09:33 +0000 (10:09 +0000)
committerjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Mon, 11 Apr 2011 10:09:33 +0000 (10:09 +0000)
git-svn-id: https://svn.bolloretelecom.eu/opensource/django-ldapdb/trunk@1023 e071eeec-0327-468d-9b6a-08194a12b294

ldapdb/backends/ldap/compiler.py
ldapdb/models/query.py

index df3247f76b720d8d7c4c701db835496027dc0c42..7d81981adec2bc95cdde5aedfe992b0946981b42 100644 (file)
@@ -35,6 +35,7 @@
 import ldap
 
 from django.db.models.sql import compiler
 import ldap
 
 from django.db.models.sql import compiler
+from django.db.models.sql.where import AND, OR
 
 def get_lookup_operator(lookup_type):
     if lookup_type == 'gte':
 
 def get_lookup_operator(lookup_type):
     if lookup_type == 'gte':
@@ -44,11 +45,17 @@ def get_lookup_operator(lookup_type):
     else:
         return '='
 
     else:
         return '='
 
-def where_as_sql(self, qn=None, connection=None):
+def query_as_ldap(query):
+    filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
+    sql, params = where_as_ldap(query.where)
+    filterstr += sql
+    return '(&%s)' % filterstr
+
+def where_as_ldap(self):
     bits = []
     for item in self.children:
         if hasattr(item, 'as_sql'):
     bits = []
     for item in self.children:
         if hasattr(item, 'as_sql'):
-            sql, params = where_as_sql(item, qn=qn, connection=connection)
+            sql, params = where_as_ldap(item)
             bits.append(sql)
             continue
 
             bits.append(sql)
             continue
 
@@ -85,12 +92,6 @@ class SQLCompiler(object):
         self.connection = connection
         self.using = using
 
         self.connection = connection
         self.using = using
 
-    def _ldap_filter(self):
-        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.query.model.object_classes])
-        sql, params = where_as_sql(self.query.where)
-        filterstr += sql
-        return '(&%s)' % filterstr
-
     def results_iter(self):
         if self.query.select_fields:
             fields = self.query.select_fields
     def results_iter(self):
         if self.query.select_fields:
             fields = self.query.select_fields
@@ -103,7 +104,7 @@ class SQLCompiler(object):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self._ldap_filter(),
+                filterstr=query_as_ldap(self.query),
                 attrlist=attrlist,
             )
         except ldap.NO_SUCH_OBJECT:
                 attrlist=attrlist,
             )
         except ldap.NO_SUCH_OBJECT:
@@ -167,7 +168,7 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self._ldap_filter(),
+                filterstr=query_as_ldap(self.query),
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT:
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT:
index 35df5df8329da532e84736cbf67751162acf3ae3..2b8237699bed2da8cab0f70c631e90abafa12b43 100644 (file)
@@ -44,14 +44,6 @@ import ldapdb
 from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
-def get_lookup_operator(lookup_type):
-    if lookup_type == 'gte':
-        return '>='
-    elif lookup_type == 'lte':
-        return '<='
-    else:
-        return '='
-
 class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
 class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
@@ -95,58 +87,13 @@ class WhereNode(BaseWhereNode):
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
             obj = Constraint(obj.alias, obj.col, obj.field)
         super(WhereNode, self).add((obj, lookup_type, value), connector)
 
-    def as_sql(self, qn=None, connection=None):
-        bits = []
-        for item in self.children:
-            if hasattr(item, 'as_sql'):
-                sql, params = item.as_sql(qn=qn, connection=connection)
-                bits.append(sql)
-                continue
-
-            constraint, lookup_type, y, values = item
-            comp = get_lookup_operator(lookup_type)
-            if lookup_type == 'in':
-                equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
-                clause = '(|%s)' % ''.join(equal_bits)
-            else:
-                clause = "(%s%s%s)" % (constraint.col, comp, values)
-
-            bits.append(clause)
-
-        if not len(bits):
-            return '', []
-
-        if len(bits) == 1:
-            sql_string = bits[0]
-        elif self.connector == AND:
-            sql_string = '(&%s)' % ''.join(bits)
-        elif self.connector == OR:
-            sql_string = '(|%s)' % ''.join(bits)
-        else:
-            raise Exception("Unhandled WHERE connector: %s" % self.connector)
-
-        if self.negated:
-            sql_string = ('(!%s)' % sql_string)
-
-        return sql_string, []
-
 class Query(BaseQuery):
 class Query(BaseQuery):
-    def __init__(self, *args, **kwargs):
-        super(Query, self).__init__(*args, **kwargs)
-        self.connection = ldapdb.connection
-
-    def _ldap_filter(self):
-        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
-        sql, params = self.where.as_sql()
-        filterstr += sql
-        return '(&%s)' % filterstr
-
     def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
                 self.model.search_scope,
     def get_count(self, using):
         try:
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
                 self.model.search_scope,
-                filterstr=self._ldap_filter(),
+                filterstr=compiler.query_as_ldap(self),
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT:
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT:
@@ -162,7 +109,7 @@ class Query(BaseQuery):
         return number
 
     def get_compiler(self, using=None, connection=None):
         return number
 
     def get_compiler(self, using=None, connection=None):
-        return compiler.SQLCompiler(self, ldapdb.connection, using)
+        return super(Query, self).get_compiler(connection=ldapdb.connection)
 
     def has_results(self, using):
         return self.get_count(using) != 0
 
     def has_results(self, using):
         return self.get_count(using) != 0
@@ -179,7 +126,7 @@ class QuerySet(BaseQuerySet):
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
                 self.model.search_scope,
             vals = ldapdb.connection.search_s(
                 self.model.base_dn,
                 self.model.search_scope,
-                filterstr=self.query._ldap_filter(),
+                filterstr=compiler.query_as_ldap(self.query),
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT:
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT: