allow order_by using a '-' prefix
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 BollorĂ© telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
33 def get_lookup_operator(lookup_type):
34     if lookup_type == 'gte':
35         return '>='
36     elif lookup_type == 'lte':
37         return '<='
38     else:
39         return '='
40
41 class Constraint(BaseConstraint):
42     """
43     An object that can be passed to WhereNode.add() and knows how to
44     pre-process itself prior to including in the WhereNode.
45     """
46     def process(self, lookup_type, value):
47         """
48         Returns a tuple of data suitable for inclusion in a WhereNode
49         instance.
50         """
51         # Because of circular imports, we need to import this here.
52         from django.db.models.base import ObjectDoesNotExist
53
54         try:
55             if self.field:
56                 params = self.field.get_db_prep_lookup(lookup_type, value)
57                 db_type = self.field.db_type()
58             else:
59                 params = CharField().get_db_prep_lookup(lookup_type, value)
60                 db_type = None
61         except ObjectDoesNotExist:
62             raise EmptyShortCircuit
63
64         return (self.alias, self.col, db_type), params
65
66 class WhereNode(BaseWhereNode):
67     def add(self, data, connector):
68         if not isinstance(data, (list, tuple)):
69             super(WhereNode, self).add(data, connector)
70             return
71
72         # we replace the native Constraint by our own
73         obj, lookup_type, value = data
74         if hasattr(obj, "process"):
75             obj = Constraint(obj.alias, obj.col, obj.field)
76         super(WhereNode, self).add((obj, lookup_type, value), connector)
77
78     def as_sql(self):
79         bits = []
80         for item in self.children:
81             if isinstance(item, WhereNode):
82                 bits.append(item.as_sql())
83                 continue
84
85             constraint, lookup_type, y, values = item
86             comp = get_lookup_operator(lookup_type)
87             if hasattr(constraint, "col"):
88                 # django 1.2
89                 column = constraint.col
90                 if lookup_type == 'in':
91                     equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
92                     clause = '(|%s)' % ''.join(equal_bits)
93                 else:
94                     clause = "(%s%s%s)" % (constraint.col, comp, values)
95             else:
96                 # django 1.1
97                 (table, column, db_type) = constraint
98                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
99                 if len(equal_bits) == 1:
100                     clause = equal_bits[0]
101                 else:
102                     clause = '(|%s)' % ''.join(equal_bits)
103
104             if self.negated:
105                 bits.append('(!%s)' % clause)
106             else:
107                 bits.append(clause)
108         if len(bits) == 1:
109             return bits[0]
110         elif self.connector == AND:
111             return '(&%s)' % ''.join(bits)
112         elif self.connector == OR:
113             return '(|%s)' % ''.join(bits)
114         else:
115             raise Exception("Unhandled WHERE connector: %s" % self.connector)
116
117 class Query(BaseQuery):
118     def results_iter(self):
119         # FIXME: use all object classes
120         filterstr = '(objectClass=%s)' % self.model.object_classes[0]
121         filterstr += self.where.as_sql()
122         filterstr = '(&%s)' % filterstr
123         attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
124
125         try:
126             vals = ldapdb.connection.search_s(
127                 self.model.base_dn,
128                 ldap.SCOPE_SUBTREE,
129                 filterstr=filterstr,
130                 attrlist=attrlist,
131             )
132         except:
133             raise self.model.DoesNotExist
134
135         # perform sorting
136         if self.extra_order_by:
137             ordering = self.extra_order_by
138         elif not self.default_ordering:
139             ordering = self.order_by
140         else:
141             ordering = self.order_by or self.model._meta.ordering
142         def cmpvals(x, y):
143             for field in ordering:
144                 if field.startswith('-'):
145                     field = field[1:]
146                     negate = True
147                 else:
148                     negate = False
149                 attr = self.model._meta.get_field(field).db_column
150                 attr_x = x[1].get(attr, '').lower()
151                 attr_y = y[1].get(attr, '').lower()
152                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
153                 if val:
154                     return val
155             return 0
156         vals = sorted(vals, cmp=cmpvals)
157
158         # process results
159         for dn, attrs in vals:
160             row = []
161             for field in iter(self.model._meta.fields):
162                 if field.attname == 'dn':
163                     row.append(dn)
164                 else:
165                     row.append(attrs.get(field.db_column, None))
166             yield row
167
168 class QuerySet(BaseQuerySet):
169     def __init__(self, model=None, query=None, using=None):
170         if not query:
171             import inspect
172             spec = inspect.getargspec(Query.__init__)
173             if len(spec[0]) == 3:
174                 # django 1.2
175                 query = Query(model, WhereNode)
176             else:
177                 # django 1.1
178                 query = Query(model, None, WhereNode)
179         super(QuerySet, self).__init__(model, query)
180