91a3f62e2058bf8587342559c04ffefee3c58487
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (c) 2009-2010, Bolloré telecom
5 # All rights reserved.
6
7 # See AUTHORS file for a full list of contributors.
8
9 # Redistribution and use in source and binary forms, with or without modification,
10 # are permitted provided that the following conditions are met:
11
12 #     1. Redistributions of source code must retain the above copyright notice, 
13 #        this list of conditions and the following disclaimer.
14 #     
15 #     2. Redistributions in binary form must reproduce the above copyright 
16 #        notice, this list of conditions and the following disclaimer in the
17 #        documentation and/or other materials provided with the distribution.
18
19 #     3. Neither the name of Bolloré telecom nor the names of its contributors
20 #        may be used to endorse or promote products derived from this software
21 #        without specific prior written permission.
22
23 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #
34
35 from copy import deepcopy
36 import ldap
37
38 from django.db.models.query import QuerySet as BaseQuerySet
39 from django.db.models.query_utils import Q
40 from django.db.models.sql import Query as BaseQuery
41 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
42
43 import ldapdb
44
45 from ldapdb.models.fields import CharField
46
47 def get_lookup_operator(lookup_type):
48     if lookup_type == 'gte':
49         return '>='
50     elif lookup_type == 'lte':
51         return '<='
52     else:
53         return '='
54
55 class Constraint(BaseConstraint):
56     """
57     An object that can be passed to WhereNode.add() and knows how to
58     pre-process itself prior to including in the WhereNode.
59
60     NOTES: 
61     - we redefine this class, because when self.field is None calls
62     Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
63     - the connection argument defaults to None for django 1.1 compatibility
64     """
65     def process(self, lookup_type, value, connection=None):
66         """
67         Returns a tuple of data suitable for inclusion in a WhereNode
68         instance.
69         """
70         # Because of circular imports, we need to import this here.
71         from django.db.models.base import ObjectDoesNotExist
72
73         try:
74             if self.field:
75                 params = self.field.get_db_prep_lookup(lookup_type, value,
76                     connection=connection, prepared=True)
77                 db_type = self.field.db_type()
78             else:
79                 params = CharField().get_db_prep_lookup(lookup_type, value,
80                     connection=connection, prepared=True)
81                 db_type = None
82         except ObjectDoesNotExist:
83             raise EmptyShortCircuit
84
85         return (self.alias, self.col, db_type), params
86
87 class Compiler(object):
88     def __init__(self, query, connection, using):
89         self.query = query
90         self.connection = connection
91         self.using = using
92
93     def results_iter(self):
94         if self.query.select_fields:
95             fields = self.query.select_fields
96         else:
97             fields = self.query.model._meta.fields
98
99         attrlist = [ x.db_column for x in fields if x.db_column ]
100
101         try:
102             vals = self.connection.search_s(
103                 self.query.model.base_dn,
104                 self.query.model.search_scope,
105                 filterstr=self.query._ldap_filter(),
106                 attrlist=attrlist,
107             )
108         except ldap.NO_SUCH_OBJECT:
109             return
110
111         # perform sorting
112         if self.query.extra_order_by:
113             ordering = self.query.extra_order_by
114         elif not self.query.default_ordering:
115             ordering = self.query.order_by
116         else:
117             ordering = self.query.order_by or self.query.model._meta.ordering
118         def cmpvals(x, y):
119             for fieldname in ordering:
120                 if fieldname.startswith('-'):
121                     fieldname = fieldname[1:]
122                     negate = True
123                 else:
124                     negate = False
125                 field = self.query.model._meta.get_field(fieldname)
126                 attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
127                 attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
128                 # perform case insensitive comparison
129                 if hasattr(attr_x, 'lower'):
130                     attr_x = attr_x.lower()
131                 if hasattr(attr_y, 'lower'):
132                     attr_y = attr_y.lower()
133                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
134                 if val:
135                     return val
136             return 0
137         vals = sorted(vals, cmp=cmpvals)
138
139         # process results
140         pos = 0
141         for dn, attrs in vals:
142             # FIXME : This is not optimal, we retrieve more results than we need
143             # but there is probably no other options as we can't perform ordering
144             # server side.
145             if (self.query.low_mark and pos < self.query.low_mark) or \
146                (self.query.high_mark is not None and pos >= self.query.high_mark):
147                 pos += 1
148                 continue
149             row = []
150             for field in iter(fields):
151                 if field.attname == 'dn':
152                     row.append(dn)
153                 elif hasattr(field, 'from_ldap'):
154                     row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
155                 else:
156                     row.append(None)
157             yield row
158             pos += 1
159
160
161 class WhereNode(BaseWhereNode):
162     def add(self, data, connector):
163         if not isinstance(data, (list, tuple)):
164             super(WhereNode, self).add(data, connector)
165             return
166
167         # we replace the native Constraint by our own
168         obj, lookup_type, value = data
169         if hasattr(obj, "process"):
170             obj = Constraint(obj.alias, obj.col, obj.field)
171         super(WhereNode, self).add((obj, lookup_type, value), connector)
172
173     def as_sql(self, qn=None, connection=None):
174         bits = []
175         for item in self.children:
176             if hasattr(item, 'as_sql'):
177                 sql, params = item.as_sql(qn=qn, connection=connection)
178                 bits.append(sql)
179                 continue
180
181             constraint, lookup_type, y, values = item
182             comp = get_lookup_operator(lookup_type)
183             if hasattr(constraint, "col"):
184                 # django 1.2
185                 column = constraint.col
186                 if lookup_type == 'in':
187                     equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
188                     clause = '(|%s)' % ''.join(equal_bits)
189                 else:
190                     clause = "(%s%s%s)" % (constraint.col, comp, values)
191             else:
192                 # django 1.1
193                 (table, column, db_type) = constraint
194                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
195                 if len(equal_bits) == 1:
196                     clause = equal_bits[0]
197                 else:
198                     clause = '(|%s)' % ''.join(equal_bits)
199
200             bits.append(clause)
201
202         if not len(bits):
203             return '', []
204
205         if len(bits) == 1:
206             sql_string = bits[0]
207         elif self.connector == AND:
208             sql_string = '(&%s)' % ''.join(bits)
209         elif self.connector == OR:
210             sql_string = '(|%s)' % ''.join(bits)
211         else:
212             raise Exception("Unhandled WHERE connector: %s" % self.connector)
213
214         if self.negated:
215             sql_string = ('(!%s)' % sql_string)
216
217         return sql_string, []
218
219 class Query(BaseQuery):
220     def __init__(self, *args, **kwargs):
221         super(Query, self).__init__(*args, **kwargs)
222         self.connection = ldapdb.connection
223
224     def _ldap_filter(self):
225         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
226         sql, params = self.where.as_sql()
227         filterstr += sql
228         return '(&%s)' % filterstr
229
230     def get_count(self, using=None):
231         try:
232             vals = ldapdb.connection.search_s(
233                 self.model.base_dn,
234                 self.model.search_scope,
235                 filterstr=self._ldap_filter(),
236                 attrlist=[],
237             )
238         except ldap.NO_SUCH_OBJECT:
239             return 0
240
241         number = len(vals)
242
243         # apply limit and offset
244         number = max(0, number - self.low_mark)
245         if self.high_mark is not None:
246             number = min(number, self.high_mark - self.low_mark)
247
248         return number
249
250     def get_compiler(self, using=None, connection=None):
251         return Compiler(self, ldapdb.connection, using)
252
253     def has_results(self, using):
254         return self.get_count() != 0
255
256     def results_iter(self):
257         "For django 1.1 compatibility"
258         return self.get_compiler().results_iter()
259
260 class QuerySet(BaseQuerySet):
261     def __init__(self, model=None, query=None, using=None):
262         if not query:
263             import inspect
264             spec = inspect.getargspec(BaseQuery.__init__)
265             if len(spec[0]) == 3:
266                 # django 1.2
267                 query = Query(model, WhereNode)
268             else:
269                 # django 1.1
270                 query = Query(model, None, WhereNode)
271         super(QuerySet, self).__init__(model=model, query=query)
272
273     def delete(self):
274         "Bulk deletion."
275         try:
276             vals = ldapdb.connection.search_s(
277                 self.model.base_dn,
278                 self.model.search_scope,
279                 filterstr=self.query._ldap_filter(),
280                 attrlist=[],
281             )
282         except ldap.NO_SUCH_OBJECT:
283             return
284
285         # FIXME : there is probably a more efficient way to do this 
286         for dn, attrs in vals:
287             ldapdb.connection.delete_s(dn)
288