X-Git-Url: https://git.stderr.nl/gitweb?p=matthijs%2Fupstream%2Fdjango-ldapdb.git;a=blobdiff_plain;f=ldapdb%2Fbackends%2Fldap%2Fcompiler.py;h=3636ede49732edf247d1580a7e5340cfb807e2d4;hp=7d81981adec2bc95cdde5aedfe992b0946981b42;hb=HEAD;hpb=0dcc0b21337a6bf4859c2a4f13c1951accf7639e diff --git a/ldapdb/backends/ldap/compiler.py b/ldapdb/backends/ldap/compiler.py index 7d81981..3636ede 100644 --- a/ldapdb/backends/ldap/compiler.py +++ b/ldapdb/backends/ldap/compiler.py @@ -34,7 +34,7 @@ import ldap -from django.db.models.sql import compiler +from django.db.models.sql import aggregates, compiler from django.db.models.sql.where import AND, OR def get_lookup_operator(lookup_type): @@ -46,7 +46,12 @@ def get_lookup_operator(lookup_type): return '=' def query_as_ldap(query): - filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes]) + # TODO: Filtering on objectClass temporarily disabled, since this + # breaks Model.save() after an objectclass was added (it queries the + # database for the old values to see what changed, but filtering on + # the new objectclasses does not return the object). + #filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes]) + filterstr = '' sql, params = where_as_ldap(query.where) filterstr += sql return '(&%s)' % filterstr @@ -92,6 +97,37 @@ class SQLCompiler(object): self.connection = connection self.using = using + def execute_sql(self, result_type=compiler.MULTI): + if result_type !=compiler.SINGLE: + raise Exception("LDAP does not support MULTI queries") + + for key, aggregate in self.query.aggregate_select.items(): + if not isinstance(aggregate, aggregates.Count): + raise Exception("Unsupported aggregate %s" % aggregate) + + try: + vals = self.connection.search_s( + self.query.model.base_dn, + self.query.model.search_scope, + filterstr=query_as_ldap(self.query), + attrlist=['dn'], + ) + except ldap.NO_SUCH_OBJECT: + vals = [] + + if not vals: + return None + + output = [] + for alias, col in self.query.extra_select.iteritems(): + output.append(col[0]) + for key, aggregate in self.query.aggregate_select.items(): + if isinstance(aggregate, aggregates.Count): + output.append(len(vals)) + else: + output.append(None) + return output + def results_iter(self): if self.query.select_fields: fields = self.query.select_fields @@ -169,7 +205,7 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): self.query.model.base_dn, self.query.model.search_scope, filterstr=query_as_ldap(self.query), - attrlist=[], + attrlist=['dn'], ) except ldap.NO_SUCH_OBJECT: return