X-Git-Url: https://git.stderr.nl/gitweb?a=blobdiff_plain;f=ldapdb%2Fmodels%2Fquery.py;h=55576a6b4be1c44cb8427e49aff7c3eb29e4b42a;hb=6035af62fccb4c3f623fd5be0072281f953790f5;hp=020d4eeb7d4a9bf3d03c6242ad296ade8789e0a2;hpb=3912bd6478e2c17a4ea171840a50d1d78ca2694c;p=matthijs%2Fupstream%2Fdjango-ldapdb.git diff --git a/ldapdb/models/query.py b/ldapdb/models/query.py index 020d4ee..55576a6 100644 --- a/ldapdb/models/query.py +++ b/ldapdb/models/query.py @@ -28,7 +28,7 @@ from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as import ldapdb -from ldapdb.models.fields import CharField +from ldapdb.models.fields import CharField, Integer, ListField def get_lookup_operator(lookup_type): if lookup_type == 'gte': @@ -75,11 +75,12 @@ class WhereNode(BaseWhereNode): obj = Constraint(obj.alias, obj.col, obj.field) super(WhereNode, self).add((obj, lookup_type, value), connector) - def as_sql(self): + def as_sql(self, qn=None, connection=None): bits = [] for item in self.children: - if isinstance(item, WhereNode): - bits.append(item.as_sql()) + if hasattr(item, 'as_sql'): + sql, params = item.as_sql(qn=qn, connection=connection) + bits.append(sql) continue constraint, lookup_type, y, values = item @@ -106,19 +107,42 @@ class WhereNode(BaseWhereNode): else: bits.append(clause) if len(bits) == 1: - return bits[0] + sql_string = bits[0] elif self.connector == AND: - return '(&%s)' % ''.join(bits) + sql_string = '(&%s)' % ''.join(bits) elif self.connector == OR: - return '(|%s)' % ''.join(bits) + sql_string = '(|%s)' % ''.join(bits) else: raise Exception("Unhandled WHERE connector: %s" % self.connector) + return sql_string, [] class Query(BaseQuery): + def __init__(self, *args, **kwargs): + super(Query, self).__init__(*args, **kwargs) + self.connection = ldapdb.connection + + def get_count(self): + filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes]) + sql, params = self.where.as_sql() + filterstr += sql + filterstr = '(&%s)' % filterstr + + try: + vals = ldapdb.connection.search_s( + self.model.base_dn, + ldap.SCOPE_SUBTREE, + filterstr=filterstr, + attrlist=[], + ) + except: + raise self.model.DoesNotExist + + return len(vals) + def results_iter(self): - # FIXME: use all object classes - filterstr = '(objectClass=%s)' % self.model.object_classes[0] - filterstr += self.where.as_sql() + filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes]) + sql, params = self.where.as_sql() + filterstr += sql filterstr = '(&%s)' % filterstr attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ] @@ -139,13 +163,21 @@ class Query(BaseQuery): ordering = self.order_by else: ordering = self.order_by or self.model._meta.ordering - def getkey(x): - keys = [] - for k in ordering: - attr = self.model._meta.get_field(k).db_column - keys.append(x[1].get(attr, '').lower()) - return keys - vals = sorted(vals, key=lambda x: getkey(x)) + def cmpvals(x, y): + for field in ordering: + if field.startswith('-'): + field = field[1:] + negate = True + else: + negate = False + attr = self.model._meta.get_field(field).db_column + attr_x = x[1].get(attr, '').lower() + attr_y = y[1].get(attr, '').lower() + val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y) + if val: + return val + return 0 + vals = sorted(vals, cmp=cmpvals) # process results for dn, attrs in vals: @@ -153,20 +185,24 @@ class Query(BaseQuery): for field in iter(self.model._meta.fields): if field.attname == 'dn': row.append(dn) + elif isinstance(field, IntegerField): + row.append(int(attrs.get(field.db_column, [0])[0])) + elif isinstance(field, ListField): + row.append(attrs.get(field.db_column, [])) else: - row.append(attrs.get(field.db_column, None)) + row.append(attrs.get(field.db_column, [''])[0]) yield row class QuerySet(BaseQuerySet): def __init__(self, model=None, query=None, using=None): if not query: import inspect - spec = inspect.getargspec(Query.__init__) + spec = inspect.getargspec(BaseQuery.__init__) if len(spec[0]) == 3: # django 1.2 query = Query(model, WhereNode) else: # django 1.1 query = Query(model, None, WhereNode) - super(QuerySet, self).__init__(model, query) + super(QuerySet, self).__init__(model=model, query=query)