From 6e4fe5321db424327fb2e85db0acfb8e1ef9b76a Mon Sep 17 00:00:00 2001 From: jlaine Date: Mon, 24 May 2010 18:08:17 +0000 Subject: [PATCH] make it possible to use search in admin interface git-svn-id: https://svn.bolloretelecom.eu/opensource/django-ldapdb/trunk@874 e071eeec-0327-468d-9b6a-08194a12b294 --- ldapdb/models/query.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/ldapdb/models/query.py b/ldapdb/models/query.py index b35c123..eb35061 100644 --- a/ldapdb/models/query.py +++ b/ldapdb/models/query.py @@ -75,11 +75,12 @@ class WhereNode(BaseWhereNode): obj = Constraint(obj.alias, obj.col, obj.field) super(WhereNode, self).add((obj, lookup_type, value), connector) - def as_sql(self): + def as_sql(self, qn=None): bits = [] for item in self.children: if isinstance(item, WhereNode): - bits.append(item.as_sql()) + sql, params = item.as_sql() + bits.append(sql) continue constraint, lookup_type, y, values = item @@ -106,19 +107,44 @@ class WhereNode(BaseWhereNode): else: bits.append(clause) if len(bits) == 1: - return bits[0] + sql_string = bits[0] elif self.connector == AND: - return '(&%s)' % ''.join(bits) + sql_string = '(&%s)' % ''.join(bits) elif self.connector == OR: - return '(|%s)' % ''.join(bits) + sql_string = '(|%s)' % ''.join(bits) else: raise Exception("Unhandled WHERE connector: %s" % self.connector) + return sql_string, [] class Query(BaseQuery): + def __init__(self, *args, **kwargs): + super(Query, self).__init__(*args, **kwargs) + self.connection = ldapdb.connection + + def get_count(self): + # FIXME: use all object classes + filterstr = '(objectClass=%s)' % self.model.object_classes[0] + sql, params = self.where.as_sql() + filterstr += sql + filterstr = '(&%s)' % filterstr + + try: + vals = ldapdb.connection.search_s( + self.model.base_dn, + ldap.SCOPE_SUBTREE, + filterstr=filterstr, + attrlist=[], + ) + except: + raise self.model.DoesNotExist + + return len(vals) + def results_iter(self): # FIXME: use all object classes filterstr = '(objectClass=%s)' % self.model.object_classes[0] - filterstr += self.where.as_sql() + sql, params = self.where.as_sql() + filterstr += sql filterstr = '(&%s)' % filterstr attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ] @@ -169,12 +195,12 @@ class QuerySet(BaseQuerySet): def __init__(self, model=None, query=None, using=None): if not query: import inspect - spec = inspect.getargspec(Query.__init__) + spec = inspect.getargspec(BaseQuery.__init__) if len(spec[0]) == 3: # django 1.2 query = Query(model, WhereNode) else: # django 1.1 query = Query(model, None, WhereNode) - super(QuerySet, self).__init__(model, query) + super(QuerySet, self).__init__(model=model, query=query) -- 2.30.2