129e5056df58b23009d5eb8939728e6b849dbdc0
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (c) 2009-2010, Bolloré telecom
5 # All rights reserved.
6
7 # See AUTHORS file for a full list of contributors.
8
9 # Redistribution and use in source and binary forms, with or without modification,
10 # are permitted provided that the following conditions are met:
11
12 #     1. Redistributions of source code must retain the above copyright notice, 
13 #        this list of conditions and the following disclaimer.
14 #     
15 #     2. Redistributions in binary form must reproduce the above copyright 
16 #        notice, this list of conditions and the following disclaimer in the
17 #        documentation and/or other materials provided with the distribution.
18
19 #     3. Neither the name of Bolloré telecom nor the names of its contributors
20 #        may be used to endorse or promote products derived from this software
21 #        without specific prior written permission.
22
23 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #
34
35 from copy import deepcopy
36 import ldap
37
38 from django.db.models.query import QuerySet as BaseQuerySet
39 from django.db.models.query_utils import Q
40 from django.db.models.sql import Query as BaseQuery
41 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
42
43 import ldapdb
44
45 from ldapdb.models.fields import CharField
46
47 def get_lookup_operator(lookup_type):
48     if lookup_type == 'gte':
49         return '>='
50     elif lookup_type == 'lte':
51         return '<='
52     else:
53         return '='
54
55 class Constraint(BaseConstraint):
56     """
57     An object that can be passed to WhereNode.add() and knows how to
58     pre-process itself prior to including in the WhereNode.
59
60     NOTES: 
61     - we redefine this class, because when self.field is None calls
62     Field().get_db_prep_lookup(), which short-circuits our LDAP-specific code.
63     """
64     def process(self, lookup_type, value, connection):
65         """
66         Returns a tuple of data suitable for inclusion in a WhereNode
67         instance.
68         """
69         # Because of circular imports, we need to import this here.
70         from django.db.models.base import ObjectDoesNotExist
71
72         try:
73             if self.field:
74                 params = self.field.get_db_prep_lookup(lookup_type, value,
75                     connection=connection, prepared=True)
76                 db_type = self.field.db_type()
77             else:
78                 params = CharField().get_db_prep_lookup(lookup_type, value,
79                     connection=connection, prepared=True)
80                 db_type = None
81         except ObjectDoesNotExist:
82             raise EmptyShortCircuit
83
84         return (self.alias, self.col, db_type), params
85
86 class Compiler(object):
87     def __init__(self, query, connection, using):
88         self.query = query
89         self.connection = connection
90         self.using = using
91
92     def results_iter(self):
93         if self.query.select_fields:
94             fields = self.query.select_fields
95         else:
96             fields = self.query.model._meta.fields
97
98         attrlist = [ x.db_column for x in fields if x.db_column ]
99
100         try:
101             vals = self.connection.search_s(
102                 self.query.model.base_dn,
103                 self.query.model.search_scope,
104                 filterstr=self.query._ldap_filter(),
105                 attrlist=attrlist,
106             )
107         except ldap.NO_SUCH_OBJECT:
108             return
109
110         # perform sorting
111         if self.query.extra_order_by:
112             ordering = self.query.extra_order_by
113         elif not self.query.default_ordering:
114             ordering = self.query.order_by
115         else:
116             ordering = self.query.order_by or self.query.model._meta.ordering
117         def cmpvals(x, y):
118             for fieldname in ordering:
119                 if fieldname.startswith('-'):
120                     fieldname = fieldname[1:]
121                     negate = True
122                 else:
123                     negate = False
124                 field = self.query.model._meta.get_field(fieldname)
125                 attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
126                 attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
127                 # perform case insensitive comparison
128                 if hasattr(attr_x, 'lower'):
129                     attr_x = attr_x.lower()
130                 if hasattr(attr_y, 'lower'):
131                     attr_y = attr_y.lower()
132                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
133                 if val:
134                     return val
135             return 0
136         vals = sorted(vals, cmp=cmpvals)
137
138         # process results
139         pos = 0
140         for dn, attrs in vals:
141             # FIXME : This is not optimal, we retrieve more results than we need
142             # but there is probably no other options as we can't perform ordering
143             # server side.
144             if (self.query.low_mark and pos < self.query.low_mark) or \
145                (self.query.high_mark is not None and pos >= self.query.high_mark):
146                 pos += 1
147                 continue
148             row = []
149             for field in iter(fields):
150                 if field.attname == 'dn':
151                     row.append(dn)
152                 elif hasattr(field, 'from_ldap'):
153                     row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
154                 else:
155                     row.append(None)
156             yield row
157             pos += 1
158
159
160 class WhereNode(BaseWhereNode):
161     def add(self, data, connector):
162         if not isinstance(data, (list, tuple)):
163             super(WhereNode, self).add(data, connector)
164             return
165
166         # we replace the native Constraint by our own
167         obj, lookup_type, value = data
168         if hasattr(obj, "process"):
169             obj = Constraint(obj.alias, obj.col, obj.field)
170         super(WhereNode, self).add((obj, lookup_type, value), connector)
171
172     def as_sql(self, qn=None, connection=None):
173         bits = []
174         for item in self.children:
175             if hasattr(item, 'as_sql'):
176                 sql, params = item.as_sql(qn=qn, connection=connection)
177                 bits.append(sql)
178                 continue
179
180             constraint, lookup_type, y, values = item
181             comp = get_lookup_operator(lookup_type)
182             if lookup_type == 'in':
183                 equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
184                 clause = '(|%s)' % ''.join(equal_bits)
185             else:
186                 clause = "(%s%s%s)" % (constraint.col, comp, values)
187
188             bits.append(clause)
189
190         if not len(bits):
191             return '', []
192
193         if len(bits) == 1:
194             sql_string = bits[0]
195         elif self.connector == AND:
196             sql_string = '(&%s)' % ''.join(bits)
197         elif self.connector == OR:
198             sql_string = '(|%s)' % ''.join(bits)
199         else:
200             raise Exception("Unhandled WHERE connector: %s" % self.connector)
201
202         if self.negated:
203             sql_string = ('(!%s)' % sql_string)
204
205         return sql_string, []
206
207 class Query(BaseQuery):
208     def __init__(self, *args, **kwargs):
209         super(Query, self).__init__(*args, **kwargs)
210         self.connection = ldapdb.connection
211
212     def _ldap_filter(self):
213         filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
214         sql, params = self.where.as_sql()
215         filterstr += sql
216         return '(&%s)' % filterstr
217
218     def get_count(self, using):
219         try:
220             vals = ldapdb.connection.search_s(
221                 self.model.base_dn,
222                 self.model.search_scope,
223                 filterstr=self._ldap_filter(),
224                 attrlist=[],
225             )
226         except ldap.NO_SUCH_OBJECT:
227             return 0
228
229         number = len(vals)
230
231         # apply limit and offset
232         number = max(0, number - self.low_mark)
233         if self.high_mark is not None:
234             number = min(number, self.high_mark - self.low_mark)
235
236         return number
237
238     def get_compiler(self, using=None, connection=None):
239         return Compiler(self, ldapdb.connection, using)
240
241     def has_results(self, using):
242         return self.get_count(using) != 0
243
244 class QuerySet(BaseQuerySet):
245     def __init__(self, model=None, query=None, using=None):
246         if not query:
247             query = Query(model, WhereNode)
248         super(QuerySet, self).__init__(model=model, query=query, using=using)
249
250     def delete(self):
251         "Bulk deletion."
252         try:
253             vals = ldapdb.connection.search_s(
254                 self.model.base_dn,
255                 self.model.search_scope,
256                 filterstr=self.query._ldap_filter(),
257                 attrlist=[],
258             )
259         except ldap.NO_SUCH_OBJECT:
260             return
261
262         # FIXME : there is probably a more efficient way to do this 
263         for dn, attrs in vals:
264             ldapdb.connection.delete_s(dn)
265