X-Git-Url: https://git.stderr.nl/gitweb?p=matthijs%2Fupstream%2Fdjango-ldapdb.git;a=blobdiff_plain;f=ldapdb%2Fbackends%2Fldap%2Fcompiler.py;h=0b6eb7f5336e2b581e9e9097b7b3db38628991e1;hp=69417a5ef6cbd0c5e32423e93cc5732adb77bbca;hb=128dc2edeee9d41b7288f0ef3e3e6183aec70a69;hpb=2d6433b98c090eae864447b1985465a3ce15805e diff --git a/ldapdb/backends/ldap/compiler.py b/ldapdb/backends/ldap/compiler.py index 69417a5..0b6eb7f 100644 --- a/ldapdb/backends/ldap/compiler.py +++ b/ldapdb/backends/ldap/compiler.py @@ -32,16 +32,87 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from django.db.models.sql import compiler - import ldap +from django.db.models.sql import aggregates, compiler +from django.db.models.sql.where import AND, OR + +def get_lookup_operator(lookup_type): + if lookup_type == 'gte': + return '>=' + elif lookup_type == 'lte': + return '<=' + else: + return '=' + +def query_as_ldap(query): + filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes]) + sql, params = where_as_ldap(query.where) + filterstr += sql + return '(&%s)' % filterstr + +def where_as_ldap(self): + bits = [] + for item in self.children: + if hasattr(item, 'as_sql'): + sql, params = where_as_ldap(item) + bits.append(sql) + continue + + constraint, lookup_type, y, values = item + comp = get_lookup_operator(lookup_type) + if lookup_type == 'in': + equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ] + clause = '(|%s)' % ''.join(equal_bits) + else: + clause = "(%s%s%s)" % (constraint.col, comp, values) + + bits.append(clause) + + if not len(bits): + return '', [] + + if len(bits) == 1: + sql_string = bits[0] + elif self.connector == AND: + sql_string = '(&%s)' % ''.join(bits) + elif self.connector == OR: + sql_string = '(|%s)' % ''.join(bits) + else: + raise Exception("Unhandled WHERE connector: %s" % self.connector) + + if self.negated: + sql_string = ('(!%s)' % sql_string) + + return sql_string, [] + class SQLCompiler(object): def __init__(self, query, connection, using): self.query = query self.connection = connection self.using = using + def execute_sql(self, result_type=compiler.MULTI): + if result_type !=compiler.SINGLE: + raise Exception("LDAP does not support MULTI queries") + + try: + vals = self.connection.search_s( + self.query.model.base_dn, + self.query.model.search_scope, + filterstr=query_as_ldap(self.query), + attrlist=['dn'], + ) + except ldap.NO_SUCH_OBJECT: + vals = [] + + output = [] + for key, aggregate in self.query.aggregate_select.items(): + if not isinstance(aggregate, aggregates.Count): + raise Exception("Unsupported aggregate %s" % aggregate) + output.append(len(vals)) + return output + def results_iter(self): if self.query.select_fields: fields = self.query.select_fields @@ -54,7 +125,7 @@ class SQLCompiler(object): vals = self.connection.search_s( self.query.model.base_dn, self.query.model.search_scope, - filterstr=self.query._ldap_filter(), + filterstr=query_as_ldap(self.query), attrlist=attrlist, ) except ldap.NO_SUCH_OBJECT: @@ -113,7 +184,20 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): - pass + def execute_sql(self, result_type=compiler.MULTI): + try: + vals = self.connection.search_s( + self.query.model.base_dn, + self.query.model.search_scope, + filterstr=query_as_ldap(self.query), + attrlist=[], + ) + except ldap.NO_SUCH_OBJECT: + return + + # FIXME : there is probably a more efficient way to do this + for dn, attrs in vals: + self.connection.delete_s(dn) class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): pass