add a minimal execute_sql() method to compiler
[matthijs/upstream/django-ldapdb.git] / ldapdb / backends / ldap / compiler.py
index a3cde180ba0d3649ac6ffa9e6992168e35219ea7..0b6eb7f5336e2b581e9e9097b7b3db38628991e1 100644 (file)
 
 import ldap
 
+from django.db.models.sql import aggregates, compiler
+from django.db.models.sql.where import AND, OR
+
+def get_lookup_operator(lookup_type):
+    if lookup_type == 'gte':
+        return '>='
+    elif lookup_type == 'lte':
+        return '<='
+    else:
+        return '='
+
+def query_as_ldap(query):
+    filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
+    sql, params = where_as_ldap(query.where)
+    filterstr += sql
+    return '(&%s)' % filterstr
+
+def where_as_ldap(self):
+    bits = []
+    for item in self.children:
+        if hasattr(item, 'as_sql'):
+            sql, params = where_as_ldap(item)
+            bits.append(sql)
+            continue
+
+        constraint, lookup_type, y, values = item
+        comp = get_lookup_operator(lookup_type)
+        if lookup_type == 'in':
+            equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+            clause = '(|%s)' % ''.join(equal_bits)
+        else:
+            clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+        bits.append(clause)
+
+    if not len(bits):
+        return '', []
+
+    if len(bits) == 1:
+        sql_string = bits[0]
+    elif self.connector == AND:
+        sql_string = '(&%s)' % ''.join(bits)
+    elif self.connector == OR:
+        sql_string = '(|%s)' % ''.join(bits)
+    else:
+        raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+    if self.negated:
+        sql_string = ('(!%s)' % sql_string)
+
+    return sql_string, []
+
 class SQLCompiler(object):
     def __init__(self, query, connection, using):
         self.query = query
         self.connection = connection
         self.using = using
 
+    def execute_sql(self, result_type=compiler.MULTI):
+        if result_type !=compiler.SINGLE:
+            raise Exception("LDAP does not support MULTI queries")
+
+        try:
+            vals = self.connection.search_s(
+                self.query.model.base_dn,
+                self.query.model.search_scope,
+                filterstr=query_as_ldap(self.query),
+                attrlist=['dn'],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            vals = []
+
+        output = []
+        for key, aggregate in self.query.aggregate_select.items():
+            if not isinstance(aggregate, aggregates.Count):
+                raise Exception("Unsupported aggregate %s" % aggregate)
+            output.append(len(vals))
+        return output
+
     def results_iter(self):
         if self.query.select_fields:
             fields = self.query.select_fields
@@ -52,7 +125,7 @@ class SQLCompiler(object):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self.query._ldap_filter(),
+                filterstr=query_as_ldap(self.query),
                 attrlist=attrlist,
             )
         except ldap.NO_SUCH_OBJECT:
@@ -107,3 +180,31 @@ class SQLCompiler(object):
             yield row
             pos += 1
 
+class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
+    pass
+
+class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
+    def execute_sql(self, result_type=compiler.MULTI):
+        try:
+            vals = self.connection.search_s(
+                self.query.model.base_dn,
+                self.query.model.search_scope,
+                filterstr=query_as_ldap(self.query),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
+        # FIXME : there is probably a more efficient way to do this 
+        for dn, attrs in vals:
+            self.connection.delete_s(dn)
+
+class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
+    pass
+
+class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
+    pass
+
+class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
+    pass
+