move compiler definition
authorjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Mon, 11 Apr 2011 08:03:52 +0000 (08:03 +0000)
committerjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Mon, 11 Apr 2011 08:03:52 +0000 (08:03 +0000)
git-svn-id: https://svn.bolloretelecom.eu/opensource/django-ldapdb/trunk@1015 e071eeec-0327-468d-9b6a-08194a12b294

ldapdb/models/query.py

index 129e5056df58b23009d5eb8939728e6b849dbdc0..35df5df8329da532e84736cbf67751162acf3ae3 100644 (file)
@@ -41,7 +41,7 @@ from django.db.models.sql import Query as BaseQuery
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
 
 import ldapdb
-
+from ldapdb.backends.ldap import compiler
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
 from ldapdb.models.fields import CharField
 
 def get_lookup_operator(lookup_type):
@@ -83,80 +83,6 @@ class Constraint(BaseConstraint):
 
         return (self.alias, self.col, db_type), params
 
 
         return (self.alias, self.col, db_type), params
 
-class Compiler(object):
-    def __init__(self, query, connection, using):
-        self.query = query
-        self.connection = connection
-        self.using = using
-
-    def results_iter(self):
-        if self.query.select_fields:
-            fields = self.query.select_fields
-        else:
-            fields = self.query.model._meta.fields
-
-        attrlist = [ x.db_column for x in fields if x.db_column ]
-
-        try:
-            vals = self.connection.search_s(
-                self.query.model.base_dn,
-                self.query.model.search_scope,
-                filterstr=self.query._ldap_filter(),
-                attrlist=attrlist,
-            )
-        except ldap.NO_SUCH_OBJECT:
-            return
-
-        # perform sorting
-        if self.query.extra_order_by:
-            ordering = self.query.extra_order_by
-        elif not self.query.default_ordering:
-            ordering = self.query.order_by
-        else:
-            ordering = self.query.order_by or self.query.model._meta.ordering
-        def cmpvals(x, y):
-            for fieldname in ordering:
-                if fieldname.startswith('-'):
-                    fieldname = fieldname[1:]
-                    negate = True
-                else:
-                    negate = False
-                field = self.query.model._meta.get_field(fieldname)
-                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
-                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
-                # perform case insensitive comparison
-                if hasattr(attr_x, 'lower'):
-                    attr_x = attr_x.lower()
-                if hasattr(attr_y, 'lower'):
-                    attr_y = attr_y.lower()
-                val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
-                if val:
-                    return val
-            return 0
-        vals = sorted(vals, cmp=cmpvals)
-
-        # process results
-        pos = 0
-        for dn, attrs in vals:
-            # FIXME : This is not optimal, we retrieve more results than we need
-            # but there is probably no other options as we can't perform ordering
-            # server side.
-            if (self.query.low_mark and pos < self.query.low_mark) or \
-               (self.query.high_mark is not None and pos >= self.query.high_mark):
-                pos += 1
-                continue
-            row = []
-            for field in iter(fields):
-                if field.attname == 'dn':
-                    row.append(dn)
-                elif hasattr(field, 'from_ldap'):
-                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
-                else:
-                    row.append(None)
-            yield row
-            pos += 1
-
-
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
 class WhereNode(BaseWhereNode):
     def add(self, data, connector):
         if not isinstance(data, (list, tuple)):
@@ -236,7 +162,7 @@ class Query(BaseQuery):
         return number
 
     def get_compiler(self, using=None, connection=None):
         return number
 
     def get_compiler(self, using=None, connection=None):
-        return Compiler(self, ldapdb.connection, using)
+        return compiler.SQLCompiler(self, ldapdb.connection, using)
 
     def has_results(self, using):
         return self.get_count(using) != 0
 
     def has_results(self, using):
         return self.get_count(using) != 0