fix "exclude" operations
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (c) 2009-2010, Bolloré telecom
5 # All rights reserved.
6
7 # See AUTHORS file for a full list of contributors.
8
9 # Redistribution and use in source and binary forms, with or without modification,
10 # are permitted provided that the following conditions are met:
11
12 #     1. Redistributions of source code must retain the above copyright notice, 
13 #        this list of conditions and the following disclaimer.
14 #     
15 #     2. Redistributions in binary form must reproduce the above copyright 
16 #        notice, this list of conditions and the following disclaimer in the
17 #        documentation and/or other materials provided with the distribution.
18
19 #     3. Neither the name of Bolloré telecom nor the names of its contributors
20 #        may be used to endorse or promote products derived from this software
21 #        without specific prior written permission.
22
23 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #
34
35 from copy import deepcopy
36 import ldap
37
38 from django.db.models.query import QuerySet as BaseQuerySet
39 from django.db.models.query_utils import Q
40 from django.db.models.sql import Query as BaseQuery
41 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
42
43 import ldapdb
44
45 from ldapdb.models.fields import CharField
46
47 def get_lookup_operator(lookup_type):
48     if lookup_type == 'gte':
49         return '>='
50     elif lookup_type == 'lte':
51         return '<='
52     else:
53         return '='
54
55 class Constraint(BaseConstraint):
56     """
57     An object that can be passed to WhereNode.add() and knows how to
58     pre-process itself prior to including in the WhereNode.
59     """
60     def process(self, lookup_type, value):
61         """
62         Returns a tuple of data suitable for inclusion in a WhereNode
63         instance.
64         """
65         # Because of circular imports, we need to import this here.
66         from django.db.models.base import ObjectDoesNotExist
67
68         try:
69             if self.field:
70                 params = self.field.get_db_prep_lookup(lookup_type, value)
71                 db_type = self.field.db_type()
72             else:
73                 params = CharField().get_db_prep_lookup(lookup_type, value)
74                 db_type = None
75         except ObjectDoesNotExist:
76             raise EmptyShortCircuit
77
78         return (self.alias, self.col, db_type), params
79
80 class Compiler(object):
81     def __init__(self, query, connection, using):
82         self.query = query
83         self.connection = connection
84         self.using = using
85
86     def results_iter(self):
87         if self.query.select_fields:
88             fields = self.query.select_fields
89         else:
90             fields = self.query.model._meta.fields
91
92         attrlist = [ x.db_column for x in fields if x.db_column ]
93
94         try:
95             vals = self.connection.search_s(
96                 self.query.model.base_dn,
97                 self.query.model.search_scope,
98                 filterstr=self.query._ldap_filter(),
99                 attrlist=attrlist,
100             )
101         except ldap.NO_SUCH_OBJECT:
102             return
103
104         # perform sorting
105         if self.query.extra_order_by:
106             ordering = self.query.extra_order_by
107         elif not self.query.default_ordering:
108             ordering = self.query.order_by
109         else:
110             ordering = self.query.order_by or self.query.model._meta.ordering
111         def cmpvals(x, y):
112             for fieldname in ordering:
113                 if fieldname.startswith('-'):
114                     fieldname = fieldname[1:]
115                     negate = True
116                 else:
117                     negate = False
118                 field = self.query.model._meta.get_field(fieldname)
119                 attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
120                 attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
121                 # perform case insensitive comparison
122                 if hasattr(attr_x, 'lower'):
123                     attr_x = attr_x.lower()
124                 if hasattr(attr_y, 'lower'):
125                     attr_y = attr_y.lower()
126                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
127                 if val:
128                     return val
129             return 0
130         vals = sorted(vals, cmp=cmpvals)
131
132         # process results
133         pos = 0
134         for dn, attrs in vals:
135             # FIXME : This is not optimal, we retrieve more results than we need
136             # but there is probably no other options as we can't perform ordering
137             # server side.
138             if (self.query.low_mark and pos < self.query.low_mark) or \
139                (self.query.high_mark is not None and pos >= self.query.high_mark):
140                 pos += 1
141                 continue
142             row = []
143             for field in iter(fields):
144                 if field.attname == 'dn':
145                     row.append(dn)
146                 elif hasattr(field, 'from_ldap'):
147                     row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
148                 else:
149                     row.append(None)
150             yield row
151             pos += 1
152
153
154 class WhereNode(BaseWhereNode):
155     def add(self, data, connector):
156         if not isinstance(data, (list, tuple)):
157             super(WhereNode, self).add(data, connector)
158             return
159
160         # we replace the native Constraint by our own
161         obj, lookup_type, value = data
162         if hasattr(obj, "process"):
163             obj = Constraint(obj.alias, obj.col, obj.field)
164         super(WhereNode, self).add((obj, lookup_type, value), connector)
165
166     def as_sql(self, qn=None, connection=None):
167         bits = []
168         for item in self.children:
169             if hasattr(item, 'as_sql'):
170                 sql, params = item.as_sql(qn=qn, connection=connection)
171                 bits.append(sql)
172                 continue
173
174             constraint, lookup_type, y, values = item
175             comp = get_lookup_operator(lookup_type)
176             if hasattr(constraint, "col"):
177                 # django 1.2
178                 column = constraint.col
179                 if lookup_type == 'in':
180                     equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
181                     clause = '(|%s)' % ''.join(equal_bits)
182                 else:
183                     clause = "(%s%s%s)" % (constraint.col, comp, values)
184             else:
185                 # django 1.1
186                 (table, column, db_type) = constraint
187                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
188                 if len(equal_bits) == 1:
189                     clause = equal_bits[0]
190                 else:
191                     clause = '(|%s)' % ''.join(equal_bits)
192
193             bits.append(clause)
194
195         if not len(bits):
196             return '', []
197
198         if len(bits) == 1:
199             sql_string = bits[0]
200         elif self.connector == AND:
201             sql_string = '(&%s)' % ''.join(bits)
202         elif self.connector == OR:
203             sql_string = '(|%s)' % ''.join(bits)
204         else:
205             raise Exception("Unhandled WHERE connector: %s" % self.connector)
206
207         if self.negated:
208             sql_string = ('(!%s)' % sql_string)
209
210         return sql_string, []
211
212 class Query(BaseQuery):
213     def __init__(self, *args, **kwargs):
214         super(Query, self).__init__(*args, **kwargs)
215         self.connection = ldapdb.connection
216
217     def _ldap_filter(self):
218         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
219         sql, params = self.where.as_sql()
220         filterstr += sql
221         return '(&%s)' % filterstr
222
223     def get_count(self, using=None):
224         try:
225             vals = ldapdb.connection.search_s(
226                 self.model.base_dn,
227                 self.model.search_scope,
228                 filterstr=self._ldap_filter(),
229                 attrlist=[],
230             )
231         except ldap.NO_SUCH_OBJECT:
232             return 0
233
234         number = len(vals)
235
236         # apply limit and offset
237         number = max(0, number - self.low_mark)
238         if self.high_mark is not None:
239             number = min(number, self.high_mark - self.low_mark)
240
241         return number
242
243     def get_compiler(self, using=None, connection=None):
244         return Compiler(self, ldapdb.connection, using)
245
246     def has_results(self, using):
247         return self.get_count() != 0
248
249     def results_iter(self):
250         "For django 1.1 compatibility"
251         return self.get_compiler().results_iter()
252
253 class QuerySet(BaseQuerySet):
254     def __init__(self, model=None, query=None, using=None):
255         if not query:
256             import inspect
257             spec = inspect.getargspec(BaseQuery.__init__)
258             if len(spec[0]) == 3:
259                 # django 1.2
260                 query = Query(model, WhereNode)
261             else:
262                 # django 1.1
263                 query = Query(model, None, WhereNode)
264         super(QuerySet, self).__init__(model=model, query=query)
265
266     def delete(self):
267         "Bulk deletion."
268         try:
269             vals = ldapdb.connection.search_s(
270                 self.model.base_dn,
271                 self.model.search_scope,
272                 filterstr=self.query._ldap_filter(),
273                 attrlist=[],
274             )
275         except ldap.NO_SUCH_OBJECT:
276             return
277
278         # FIXME : there is probably a more efficient way to do this 
279         for dn, attrs in vals:
280             ldapdb.connection.delete_s(dn)
281