add a minimal execute_sql() method to compiler
[matthijs/upstream/django-ldapdb.git] / ldapdb / backends / ldap / compiler.py
index df3247f76b720d8d7c4c701db835496027dc0c42..0b6eb7f5336e2b581e9e9097b7b3db38628991e1 100644 (file)
@@ -34,7 +34,8 @@
 
 import ldap
 
-from django.db.models.sql import compiler
+from django.db.models.sql import aggregates, compiler
+from django.db.models.sql.where import AND, OR
 
 def get_lookup_operator(lookup_type):
     if lookup_type == 'gte':
@@ -44,11 +45,17 @@ def get_lookup_operator(lookup_type):
     else:
         return '='
 
-def where_as_sql(self, qn=None, connection=None):
+def query_as_ldap(query):
+    filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
+    sql, params = where_as_ldap(query.where)
+    filterstr += sql
+    return '(&%s)' % filterstr
+
+def where_as_ldap(self):
     bits = []
     for item in self.children:
         if hasattr(item, 'as_sql'):
-            sql, params = where_as_sql(item, qn=qn, connection=connection)
+            sql, params = where_as_ldap(item)
             bits.append(sql)
             continue
 
@@ -85,11 +92,26 @@ class SQLCompiler(object):
         self.connection = connection
         self.using = using
 
-    def _ldap_filter(self):
-        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.query.model.object_classes])
-        sql, params = where_as_sql(self.query.where)
-        filterstr += sql
-        return '(&%s)' % filterstr
+    def execute_sql(self, result_type=compiler.MULTI):
+        if result_type !=compiler.SINGLE:
+            raise Exception("LDAP does not support MULTI queries")
+
+        try:
+            vals = self.connection.search_s(
+                self.query.model.base_dn,
+                self.query.model.search_scope,
+                filterstr=query_as_ldap(self.query),
+                attrlist=['dn'],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            vals = []
+
+        output = []
+        for key, aggregate in self.query.aggregate_select.items():
+            if not isinstance(aggregate, aggregates.Count):
+                raise Exception("Unsupported aggregate %s" % aggregate)
+            output.append(len(vals))
+        return output
 
     def results_iter(self):
         if self.query.select_fields:
@@ -103,7 +125,7 @@ class SQLCompiler(object):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self._ldap_filter(),
+                filterstr=query_as_ldap(self.query),
                 attrlist=attrlist,
             )
         except ldap.NO_SUCH_OBJECT:
@@ -167,7 +189,7 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self._ldap_filter(),
+                filterstr=query_as_ldap(self.query),
                 attrlist=[],
             )
         except ldap.NO_SUCH_OBJECT: