X-Git-Url: https://git.stderr.nl/gitweb?p=matthijs%2Fupstream%2Fdjango-ldapdb.git;a=blobdiff_plain;f=ldapdb%2Fbackends%2Fldap%2Fcompiler.py;h=df3247f76b720d8d7c4c701db835496027dc0c42;hp=a3cde180ba0d3649ac6ffa9e6992168e35219ea7;hb=34230614a161ca1d9b43af5d528d64f10dd5ea73;hpb=eaafba42e8cb7e21e68ab0a4be0a76d326818b91 diff --git a/ldapdb/backends/ldap/compiler.py b/ldapdb/backends/ldap/compiler.py index a3cde18..df3247f 100644 --- a/ldapdb/backends/ldap/compiler.py +++ b/ldapdb/backends/ldap/compiler.py @@ -34,12 +34,63 @@ import ldap +from django.db.models.sql import compiler + +def get_lookup_operator(lookup_type): + if lookup_type == 'gte': + return '>=' + elif lookup_type == 'lte': + return '<=' + else: + return '=' + +def where_as_sql(self, qn=None, connection=None): + bits = [] + for item in self.children: + if hasattr(item, 'as_sql'): + sql, params = where_as_sql(item, qn=qn, connection=connection) + bits.append(sql) + continue + + constraint, lookup_type, y, values = item + comp = get_lookup_operator(lookup_type) + if lookup_type == 'in': + equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ] + clause = '(|%s)' % ''.join(equal_bits) + else: + clause = "(%s%s%s)" % (constraint.col, comp, values) + + bits.append(clause) + + if not len(bits): + return '', [] + + if len(bits) == 1: + sql_string = bits[0] + elif self.connector == AND: + sql_string = '(&%s)' % ''.join(bits) + elif self.connector == OR: + sql_string = '(|%s)' % ''.join(bits) + else: + raise Exception("Unhandled WHERE connector: %s" % self.connector) + + if self.negated: + sql_string = ('(!%s)' % sql_string) + + return sql_string, [] + class SQLCompiler(object): def __init__(self, query, connection, using): self.query = query self.connection = connection self.using = using + def _ldap_filter(self): + filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.query.model.object_classes]) + sql, params = where_as_sql(self.query.where) + filterstr += sql + return '(&%s)' % filterstr + def results_iter(self): if self.query.select_fields: fields = self.query.select_fields @@ -52,7 +103,7 @@ class SQLCompiler(object): vals = self.connection.search_s( self.query.model.base_dn, self.query.model.search_scope, - filterstr=self.query._ldap_filter(), + filterstr=self._ldap_filter(), attrlist=attrlist, ) except ldap.NO_SUCH_OBJECT: @@ -107,3 +158,31 @@ class SQLCompiler(object): yield row pos += 1 +class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): + pass + +class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): + def execute_sql(self, result_type=compiler.MULTI): + try: + vals = self.connection.search_s( + self.query.model.base_dn, + self.query.model.search_scope, + filterstr=self._ldap_filter(), + attrlist=[], + ) + except ldap.NO_SUCH_OBJECT: + return + + # FIXME : there is probably a more efficient way to do this + for dn, attrs in vals: + self.connection.delete_s(dn) + +class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): + pass + +class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): + pass + +class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): + pass +