move LDAP filter compiling
[matthijs/upstream/django-ldapdb.git] / ldapdb / backends / ldap / compiler.py
index 69417a5ef6cbd0c5e32423e93cc5732adb77bbca..df3247f76b720d8d7c4c701db835496027dc0c42 100644 (file)
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 
+import ldap
+
 from django.db.models.sql import compiler
 
-import ldap
+def get_lookup_operator(lookup_type):
+    if lookup_type == 'gte':
+        return '>='
+    elif lookup_type == 'lte':
+        return '<='
+    else:
+        return '='
+
+def where_as_sql(self, qn=None, connection=None):
+    bits = []
+    for item in self.children:
+        if hasattr(item, 'as_sql'):
+            sql, params = where_as_sql(item, qn=qn, connection=connection)
+            bits.append(sql)
+            continue
+
+        constraint, lookup_type, y, values = item
+        comp = get_lookup_operator(lookup_type)
+        if lookup_type == 'in':
+            equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+            clause = '(|%s)' % ''.join(equal_bits)
+        else:
+            clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+        bits.append(clause)
+
+    if not len(bits):
+        return '', []
+
+    if len(bits) == 1:
+        sql_string = bits[0]
+    elif self.connector == AND:
+        sql_string = '(&%s)' % ''.join(bits)
+    elif self.connector == OR:
+        sql_string = '(|%s)' % ''.join(bits)
+    else:
+        raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+    if self.negated:
+        sql_string = ('(!%s)' % sql_string)
+
+    return sql_string, []
 
 class SQLCompiler(object):
     def __init__(self, query, connection, using):
@@ -42,6 +85,12 @@ class SQLCompiler(object):
         self.connection = connection
         self.using = using
 
+    def _ldap_filter(self):
+        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.query.model.object_classes])
+        sql, params = where_as_sql(self.query.where)
+        filterstr += sql
+        return '(&%s)' % filterstr
+
     def results_iter(self):
         if self.query.select_fields:
             fields = self.query.select_fields
@@ -54,7 +103,7 @@ class SQLCompiler(object):
             vals = self.connection.search_s(
                 self.query.model.base_dn,
                 self.query.model.search_scope,
-                filterstr=self.query._ldap_filter(),
+                filterstr=self._ldap_filter(),
                 attrlist=attrlist,
             )
         except ldap.NO_SUCH_OBJECT:
@@ -113,7 +162,20 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
     pass
 
 class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
-    pass
+    def execute_sql(self, result_type=compiler.MULTI):
+        try:
+            vals = self.connection.search_s(
+                self.query.model.base_dn,
+                self.query.model.search_scope,
+                filterstr=self._ldap_filter(),
+                attrlist=[],
+            )
+        except ldap.NO_SUCH_OBJECT:
+            return
+
+        # FIXME : there is probably a more efficient way to do this 
+        for dn, attrs in vals:
+            self.connection.delete_s(dn)
 
 class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
     pass