ab5f19ec727b574cc845d44995d1b4371611ead7
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (c) 2009-2010, Bolloré telecom
5 # All rights reserved.
6
7 # See AUTHORS file for a full list of contributors.
8
9 # Redistribution and use in source and binary forms, with or without modification,
10 # are permitted provided that the following conditions are met:
11
12 #     1. Redistributions of source code must retain the above copyright notice, 
13 #        this list of conditions and the following disclaimer.
14 #     
15 #     2. Redistributions in binary form must reproduce the above copyright 
16 #        notice, this list of conditions and the following disclaimer in the
17 #        documentation and/or other materials provided with the distribution.
18
19 #     3. Neither the name of Bolloré telecom nor the names of its contributors
20 #        may be used to endorse or promote products derived from this software
21 #        without specific prior written permission.
22
23 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #
34
35 from copy import deepcopy
36 import ldap
37
38 from django.db.models.query import QuerySet as BaseQuerySet
39 from django.db.models.query_utils import Q
40 from django.db.models.sql import Query as BaseQuery
41 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
42
43 import ldapdb
44
45 from ldapdb.models.fields import CharField
46
47 def get_lookup_operator(lookup_type):
48     if lookup_type == 'gte':
49         return '>='
50     elif lookup_type == 'lte':
51         return '<='
52     else:
53         return '='
54
55 class Constraint(BaseConstraint):
56     """
57     An object that can be passed to WhereNode.add() and knows how to
58     pre-process itself prior to including in the WhereNode.
59     """
60     def process(self, lookup_type, value):
61         """
62         Returns a tuple of data suitable for inclusion in a WhereNode
63         instance.
64         """
65         # Because of circular imports, we need to import this here.
66         from django.db.models.base import ObjectDoesNotExist
67
68         try:
69             if self.field:
70                 params = self.field.get_db_prep_lookup(lookup_type, value)
71                 db_type = self.field.db_type()
72             else:
73                 params = CharField().get_db_prep_lookup(lookup_type, value)
74                 db_type = None
75         except ObjectDoesNotExist:
76             raise EmptyShortCircuit
77
78         return (self.alias, self.col, db_type), params
79
80 class Compiler(object):
81     def __init__(self, query, connection, using):
82         self.query = query
83         self.connection = connection
84         self.using = using
85
86     def results_iter(self):
87         if self.query.select_fields:
88             fields = self.query.select_fields
89         else:
90             fields = self.query.model._meta.fields
91
92         attrlist = [ x.db_column for x in fields if x.db_column ]
93
94         try:
95             vals = self.connection.search_s(
96                 self.query.model.base_dn,
97                 self.query.model.search_scope,
98                 filterstr=self.query._ldap_filter(),
99                 attrlist=attrlist,
100             )
101         except ldap.NO_SUCH_OBJECT:
102             return
103
104         # perform sorting
105         if self.query.extra_order_by:
106             ordering = self.query.extra_order_by
107         elif not self.query.default_ordering:
108             ordering = self.query.order_by
109         else:
110             ordering = self.query.order_by or self.query.model._meta.ordering
111         def cmpvals(x, y):
112             for fieldname in ordering:
113                 if fieldname.startswith('-'):
114                     fieldname = fieldname[1:]
115                     negate = True
116                 else:
117                     negate = False
118                 field = self.query.model._meta.get_field(fieldname)
119                 attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
120                 attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
121                 # perform case insensitive comparison
122                 if hasattr(attr_x, 'lower'):
123                     attr_x = attr_x.lower()
124                 if hasattr(attr_y, 'lower'):
125                     attr_y = attr_y.lower()
126                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
127                 if val:
128                     return val
129             return 0
130         vals = sorted(vals, cmp=cmpvals)
131
132         # process results
133         pos = 0
134         for dn, attrs in vals:
135             # FIXME : This is not optimal, we retrieve more results than we need
136             # but there is probably no other options as we can't perform ordering
137             # server side.
138             if (self.query.low_mark and pos < self.query.low_mark) or \
139                (self.query.high_mark is not None and pos >= self.query.high_mark):
140                 pos += 1
141                 continue
142             row = []
143             for field in iter(fields):
144                 if field.attname == 'dn':
145                     row.append(dn)
146                 elif hasattr(field, 'from_ldap'):
147                     row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
148                 else:
149                     row.append(None)
150             yield row
151             pos += 1
152
153
154 class WhereNode(BaseWhereNode):
155     def add(self, data, connector):
156         if not isinstance(data, (list, tuple)):
157             super(WhereNode, self).add(data, connector)
158             return
159
160         # we replace the native Constraint by our own
161         obj, lookup_type, value = data
162         if hasattr(obj, "process"):
163             obj = Constraint(obj.alias, obj.col, obj.field)
164         super(WhereNode, self).add((obj, lookup_type, value), connector)
165
166     def as_sql(self, qn=None, connection=None):
167         bits = []
168         for item in self.children:
169             if hasattr(item, 'as_sql'):
170                 sql, params = item.as_sql(qn=qn, connection=connection)
171                 bits.append(sql)
172                 continue
173
174             constraint, lookup_type, y, values = item
175             comp = get_lookup_operator(lookup_type)
176             if hasattr(constraint, "col"):
177                 # django 1.2
178                 column = constraint.col
179                 if lookup_type == 'in':
180                     equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
181                     clause = '(|%s)' % ''.join(equal_bits)
182                 else:
183                     clause = "(%s%s%s)" % (constraint.col, comp, values)
184             else:
185                 # django 1.1
186                 (table, column, db_type) = constraint
187                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
188                 if len(equal_bits) == 1:
189                     clause = equal_bits[0]
190                 else:
191                     clause = '(|%s)' % ''.join(equal_bits)
192
193             if self.negated:
194                 bits.append('(!%s)' % clause)
195             else:
196                 bits.append(clause)
197         if len(bits) == 1:
198             sql_string = bits[0]
199         elif self.connector == AND:
200             sql_string = '(&%s)' % ''.join(bits)
201         elif self.connector == OR:
202             sql_string = '(|%s)' % ''.join(bits)
203         else:
204             raise Exception("Unhandled WHERE connector: %s" % self.connector)
205         return sql_string, []
206
207 class Query(BaseQuery):
208     def __init__(self, *args, **kwargs):
209         super(Query, self).__init__(*args, **kwargs)
210         self.connection = ldapdb.connection
211
212     def _ldap_filter(self):
213         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
214         sql, params = self.where.as_sql()
215         filterstr += sql
216         return '(&%s)' % filterstr
217
218     def get_count(self, using=None):
219         try:
220             vals = ldapdb.connection.search_s(
221                 self.model.base_dn,
222                 self.model.search_scope,
223                 filterstr=self._ldap_filter(),
224                 attrlist=[],
225             )
226         except ldap.NO_SUCH_OBJECT:
227             return 0
228
229         number = len(vals)
230
231         # apply limit and offset
232         number = max(0, number - self.low_mark)
233         if self.high_mark is not None:
234             number = min(number, self.high_mark - self.low_mark)
235
236         return number
237
238     def get_compiler(self, using=None, connection=None):
239         return Compiler(self, ldapdb.connection, using)
240
241     def has_results(self, using):
242         return self.get_count() != 0
243
244     def results_iter(self):
245         "For django 1.1 compatibility"
246         return self.get_compiler().results_iter()
247
248 class QuerySet(BaseQuerySet):
249     def __init__(self, model=None, query=None, using=None):
250         if not query:
251             import inspect
252             spec = inspect.getargspec(BaseQuery.__init__)
253             if len(spec[0]) == 3:
254                 # django 1.2
255                 query = Query(model, WhereNode)
256             else:
257                 # django 1.1
258                 query = Query(model, None, WhereNode)
259         super(QuerySet, self).__init__(model=model, query=query)
260
261     def delete(self):
262         "Bulk deletion."
263         try:
264             vals = ldapdb.connection.search_s(
265                 self.model.base_dn,
266                 self.model.search_scope,
267                 filterstr=self.query._ldap_filter(),
268                 attrlist=[],
269             )
270         except ldap.NO_SUCH_OBJECT:
271             return
272
273         # FIXME : there is probably a more efficient way to do this 
274         for dn, attrs in vals:
275             ldapdb.connection.delete_s(dn)
276