From: jlaine Date: Mon, 11 Apr 2011 09:49:56 +0000 (+0000) Subject: move LDAP filter compiling X-Git-Url: https://git.stderr.nl/gitweb?p=matthijs%2Fupstream%2Fdjango-ldapdb.git;a=commitdiff_plain;h=34230614a161ca1d9b43af5d528d64f10dd5ea73 move LDAP filter compiling git-svn-id: https://svn.bolloretelecom.eu/opensource/django-ldapdb/trunk@1021 e071eeec-0327-468d-9b6a-08194a12b294 --- diff --git a/ldapdb/backends/ldap/compiler.py b/ldapdb/backends/ldap/compiler.py index 69417a5..df3247f 100644 --- a/ldapdb/backends/ldap/compiler.py +++ b/ldapdb/backends/ldap/compiler.py @@ -32,9 +32,52 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +import ldap + from django.db.models.sql import compiler -import ldap +def get_lookup_operator(lookup_type): + if lookup_type == 'gte': + return '>=' + elif lookup_type == 'lte': + return '<=' + else: + return '=' + +def where_as_sql(self, qn=None, connection=None): + bits = [] + for item in self.children: + if hasattr(item, 'as_sql'): + sql, params = where_as_sql(item, qn=qn, connection=connection) + bits.append(sql) + continue + + constraint, lookup_type, y, values = item + comp = get_lookup_operator(lookup_type) + if lookup_type == 'in': + equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ] + clause = '(|%s)' % ''.join(equal_bits) + else: + clause = "(%s%s%s)" % (constraint.col, comp, values) + + bits.append(clause) + + if not len(bits): + return '', [] + + if len(bits) == 1: + sql_string = bits[0] + elif self.connector == AND: + sql_string = '(&%s)' % ''.join(bits) + elif self.connector == OR: + sql_string = '(|%s)' % ''.join(bits) + else: + raise Exception("Unhandled WHERE connector: %s" % self.connector) + + if self.negated: + sql_string = ('(!%s)' % sql_string) + + return sql_string, [] class SQLCompiler(object): def __init__(self, query, connection, using): @@ -42,6 +85,12 @@ class SQLCompiler(object): self.connection = connection self.using = using + def _ldap_filter(self): + filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.query.model.object_classes]) + sql, params = where_as_sql(self.query.where) + filterstr += sql + return '(&%s)' % filterstr + def results_iter(self): if self.query.select_fields: fields = self.query.select_fields @@ -54,7 +103,7 @@ class SQLCompiler(object): vals = self.connection.search_s( self.query.model.base_dn, self.query.model.search_scope, - filterstr=self.query._ldap_filter(), + filterstr=self._ldap_filter(), attrlist=attrlist, ) except ldap.NO_SUCH_OBJECT: @@ -113,7 +162,20 @@ class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): - pass + def execute_sql(self, result_type=compiler.MULTI): + try: + vals = self.connection.search_s( + self.query.model.base_dn, + self.query.model.search_scope, + filterstr=self._ldap_filter(), + attrlist=[], + ) + except ldap.NO_SUCH_OBJECT: + return + + # FIXME : there is probably a more efficient way to do this + for dn, attrs in vals: + self.connection.delete_s(dn) class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): pass