07a6312bc182f2e6d3cb240f8198b204758df9e8
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    'gte' becomes '>=' and 'lte' becomes '<='; any other lookup type
    falls back to the equality operator '='.
    """
    for name, operator in (('gte', '>='), ('lte', '<=')):
        if lookup_type == name:
            return operator
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    When the constraint has no model field, ldapdb's CharField is used to
    prepare the lookup value.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance, shaped ``((alias, column, db_type), params)``.

        ``params`` is the result of the field's ``get_db_prep_lookup()``
        for ``lookup_type``/``value``; ``db_type`` is None when the
        constraint has no field.

        Raises EmptyShortCircuit (Django's "matches nothing" signal) when
        preparing the value raises ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit lives next to Django's own Constraint; it used to
        # be referenced without an import, so this path raised NameError.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Evaluates an ldapdb Query against the LDAP directory.

    Only ``results_iter()`` is provided: it runs the LDAP search, sorts the
    entries client-side and then applies the query's slice bounds.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def _get_ordering(self):
        """Return the field names to sort on; a leading '-' means descending."""
        query = self.query
        if query.extra_order_by:
            return query.extra_order_by
        if not query.default_ordering:
            return query.order_by
        return query.order_by or query.model._meta.ordering

    def _sort_entries(self, entries, ordering):
        """
        Sort a list of ``(dn, attrs)`` search results in place by ``ordering``.

        Values are decoded with each field's ``from_ldap()`` and compared
        case-insensitively when they have a ``lower()`` method. One stable
        sort per field, applied from the least to the most significant
        field, yields the same order as comparing field by field, while
        decoding each value once per entry instead of once per comparison.
        """
        meta = self.query.model._meta
        connection = self.connection
        for fieldname in reversed(list(ordering)):
            descending = fieldname.startswith('-')
            if descending:
                fieldname = fieldname[1:]
            field = meta.get_field(fieldname)

            def sort_key(entry, field=field):
                value = field.from_ldap(entry[1].get(field.db_column, []),
                                        connection=connection)
                # perform case insensitive comparison
                if hasattr(value, 'lower'):
                    value = value.lower()
                return value

            # reverse=True keeps the sort stable, so entries that tie on this
            # field keep the order established by the less significant fields.
            entries.sort(key=sort_key, reverse=descending)

    def results_iter(self):
        """
        Yield one row per matching LDAP entry, as a list of python values.

        Columns follow ``query.select_fields`` when it is set, otherwise the
        model's fields. A field whose attname is 'dn' receives the entry's
        DN; fields without a ``from_ldap()`` method yield None. Nothing is
        yielded when the search raises ldap.NO_SUCH_OBJECT.
        """
        query = self.query
        fields = query.select_fields or query.model._meta.fields
        attrlist = [field.db_column for field in fields if field.db_column]

        try:
            entries = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=query._ldap_filter(),
                attrlist=attrlist,
            )
        except ldap.NO_SUCH_OBJECT:
            return

        entries = list(entries)
        self._sort_entries(entries, self._get_ordering())

        # FIXME : This is not optimal, we retrieve more results than we need
        # but there is probably no other options as we can't perform ordering
        # server side. A low_mark of 0 and a high_mark of None both mean
        # "unbounded", which is exactly how slice bounds behave.
        for dn, attrs in entries[query.low_mark:query.high_mark]:
            row = []
            for field in fields:
                if field.attname == 'dn':
                    row.append(dn)
                elif hasattr(field, 'from_ldap'):
                    row.append(field.from_ldap(attrs.get(field.db_column, []),
                                               connection=self.connection))
                else:
                    row.append(None)
            yield row
138
139
class WhereNode(BaseWhereNode):
    """
    WHERE-clause tree that renders itself as an LDAP search filter
    (prefix notation: ``(&...)``, ``(|...)``, ``(!...)``) instead of SQL.
    """
    def add(self, data, connector):
        """
        Add ``data`` to the tree, replacing Django's Constraint with the
        ldapdb Constraint defined in this module.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def _leaf_filter(self, constraint, lookup_type, values):
        """Render one leaf child of the tree as an LDAP filter component."""
        comp = get_lookup_operator(lookup_type)
        if hasattr(constraint, "col"):
            # django 1.2: ``values`` is a single value, or an iterable of
            # values for the 'in' lookup.
            column = constraint.col
            if lookup_type == 'in':
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                return '(|%s)' % ''.join(equal_bits)
            return "(%s%s%s)" % (column, comp, values)

        # django 1.1: ``constraint`` is the (alias, column, db_type) tuple
        # returned by Constraint.process() and ``values`` its params list.
        _alias, column, _db_type = constraint
        equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
        if len(equal_bits) == 1:
            return equal_bits[0]
        return '(|%s)' % ''.join(equal_bits)

    def as_sql(self, qn=None, connection=None):
        """
        Return ``(filter, [])`` for this node.

        Children are combined with ``(&...)`` or ``(|...)`` according to the
        node's connector, and the combined filter is wrapped in ``(!...)``
        when the node is negated. ``filter`` is '' when no child produced
        anything. Values are interpolated as-is and no params are returned.
        NOTE(review): escaping of special filter characters is assumed to
        happen earlier (in the fields' lookup preparation) — confirm.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                sql, params = item.as_sql(qn=qn, connection=connection)
                # Skip empty child nodes; they would otherwise yield
                # malformed filters such as '(!)'.
                if sql:
                    bits.append(sql)
                continue
            constraint, lookup_type, _annotation, values = item
            bits.append(self._leaf_filter(constraint, lookup_type, values))

        if not bits:
            return '', []
        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # Negate the node as a whole: NOT (a AND b) must not turn into
        # (NOT a) AND (NOT b), and the negation must also cover child nodes.
        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
192
class Query(BaseQuery):
    """
    Django Query whose WHERE tree is rendered as an LDAP filter and whose
    results are produced by this module's Compiler.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _ldap_filter(self):
        """
        Build the search filter: one (objectClass=...) term per entry of the
        model's ``object_classes``, AND-ed with the WHERE tree's filter.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def get_count(self, using=None):
        """
        Return the number of matching entries, honouring the query's slice
        bounds (low_mark/high_mark) the same way results_iter() does.
        Returns 0 when the search raises ldap.NO_SUCH_OBJECT.
        """
        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self._ldap_filter(),
                attrlist=[],
            )
        except ldap.NO_SUCH_OBJECT:
            return 0

        # The LDAP search ignores slicing, so clamp the count to the
        # requested window (as Django's own get_count does for SQL).
        low_mark = self.low_mark or 0
        number = max(0, len(vals) - low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - low_mark)
        return number

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler bound to the ldapdb connection."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
222
class QuerySet(BaseQuerySet):
    """
    QuerySet backed by an ldapdb Query (and therefore an LDAP directory).
    """
    def __init__(self, model=None, query=None, using=None):
        # ``using`` is accepted but not forwarded; presumably because the
        # Django 1.1 QuerySet constructor does not take it.
        if not query:
            import inspect
            # Query.__init__ takes (self, model, where) in django 1.2 and an
            # extra positional argument before ``where`` in django 1.1.
            spec = inspect.getargspec(BaseQuery.__init__)
            if len(spec[0]) == 3:
                # django 1.2
                query = Query(model, WhereNode)
            else:
                # django 1.1
                query = Query(model, None, WhereNode)
        super(QuerySet, self).__init__(model=model, query=query)

    def delete(self):
        """
        Bulk deletion: delete every entry matched by this queryset's filter.

        Raises AssertionError for sliced querysets, like Django's own
        QuerySet.delete(): the LDAP search ignores the slice, so deleting
        would otherwise remove every matching entry, not just the slice.
        """
        if self.query.low_mark or self.query.high_mark is not None:
            raise AssertionError("Cannot use 'limit' or 'offset' with delete.")

        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self.query._ldap_filter(),
                attrlist=[],
            )
        except ldap.NO_SUCH_OBJECT:
            return

        # FIXME : there is probably a more efficient way to do this
        for dn, attrs in vals:
            ldapdb.connection.delete_s(dn)

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
251