fix slice handling
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    'gte' becomes '>=' and 'lte' becomes '<='; any other lookup type
    falls back to plain equality ('=').
    """
    for name, operator in (('gte', '>='), ('lte', '<=')):
        if lookup_type == name:
            return operator
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        ``params`` is ``value`` prepared by the field's
        ``get_db_prep_lookup()``. When the constraint has no field, an
        ldapdb ``CharField`` is used to prepare the value and ``db_type``
        is None.

        Raises EmptyShortCircuit if preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not imported at module level; without this
        # import the except clause below would fail with a NameError.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Executes a Query against the LDAP connection and yields model rows.

    Ordering and slicing (``low_mark`` / ``high_mark``) are applied in
    Python after the search returns, because they are not performed
    server side here.
    """

    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def _get_ordering(self):
        """Return the sequence of field names the results are sorted by."""
        query = self.query
        if query.extra_order_by:
            return query.extra_order_by
        if not query.default_ordering:
            return query.order_by
        return query.order_by or query.model._meta.ordering

    def _sort(self, vals, ordering):
        """
        Sort raw ``(dn, attrs)`` search results according to ``ordering``.

        Each entry of ``ordering`` is a model field name, optionally
        prefixed with '-' for descending order. Values exposing ``lower()``
        are compared case-insensitively.
        """
        if not ordering or len(vals) < 2:
            return vals

        meta = self.query.model._meta
        connection = self.connection
        # One stable sort per ordering field, least significant field first:
        # the end result is a lexicographic sort over all fields. ``reverse``
        # keeps sort stability, so ties still fall through to later fields.
        for fieldname in reversed(list(ordering)):
            negate = fieldname.startswith('-')
            if negate:
                fieldname = fieldname[1:]
            field = meta.get_field(fieldname)

            def sort_key(entry, field=field):
                value = field.from_ldap(entry[1].get(field.db_column, []),
                                        connection=connection)
                # perform case insensitive comparison
                if hasattr(value, 'lower'):
                    value = value.lower()
                return value

            vals = sorted(vals, key=sort_key, reverse=negate)
        return vals

    def results_iter(self):
        """
        Yield one row per matching LDAP entry.

        Each row is a list with one value per field in
        ``model._meta.fields``: the entry's DN for the 'dn' field, the
        ``from_ldap()`` conversion for fields that define it, else None.
        Yields nothing if the search raises ``ldap.NO_SUCH_OBJECT``.
        """
        query = self.query
        attrlist = [x.db_column for x in query.model._meta.local_fields if x.db_column]

        try:
            vals = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=query._ldap_filter(),
                attrlist=attrlist,
            )
        except ldap.NO_SUCH_OBJECT:
            return

        vals = self._sort(vals, self._get_ordering())

        # FIXME : This is not optimal, we retrieve more results than we need
        # but there is probably no other options as we can't perform ordering
        # server side.
        vals = vals[(query.low_mark or 0):query.high_mark]

        fields = query.model._meta.fields
        for dn, attrs in vals:
            row = []
            for field in fields:
                if field.attname == 'dn':
                    row.append(dn)
                elif hasattr(field, 'from_ldap'):
                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
                else:
                    row.append(None)
            yield row
135
136
class WhereNode(BaseWhereNode):
    """
    Where-clause tree rendered as an LDAP search filter (prefix notation)
    instead of SQL.
    """

    def add(self, data, connector):
        """
        Add ``data`` to the tree. For ``(obj, lookup_type, value)`` tuples
        whose ``obj`` has a ``process`` method, ``obj`` is replaced by this
        module's LDAP-aware ``Constraint``.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self, qn=None, connection=None):
        """
        Return ``(filter_string, [])`` for this node.

        Child filters are joined with '&' or '|' according to the node's
        connector. When the node is negated, the whole combined filter is
        wrapped in '(!...)', so NOT (a AND b) renders as '(!(&ab))' rather
        than '(&(!a)(!b))'. Child nodes rendering to an empty string are
        skipped, and a node with nothing to render returns ''.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                sql, params = item.as_sql(qn=qn, connection=connection)
                if sql:
                    bits.append(sql)
                continue

            constraint, lookup_type, y, values = item
            comp = get_lookup_operator(lookup_type)
            if hasattr(constraint, "col"):
                # django 1.2
                column = constraint.col
                if lookup_type == 'in':
                    equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                    clause = '(|%s)' % ''.join(equal_bits)
                else:
                    clause = "(%s%s%s)" % (column, comp, values)
            else:
                # django 1.1
                (table, column, db_type) = constraint
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                if len(equal_bits) == 1:
                    clause = equal_bits[0]
                else:
                    clause = '(|%s)' % ''.join(equal_bits)
            bits.append(clause)

        if not bits:
            return '', []

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # Negation applies to the node as a whole, not to each child.
        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
189
class Query(BaseQuery):
    """Query subclass that builds LDAP filters and counts LDAP entries."""

    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _ldap_filter(self):
        """
        Return the LDAP filter for this query: an AND of one
        '(objectClass=...)' clause per ``model.object_classes`` entry plus
        the where-tree filter.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def get_count(self, using=None):
        """
        Return the number of entries matching the query, honouring any
        slice (``low_mark`` / ``high_mark``) applied to the queryset.
        Returns 0 if the search raises ``ldap.NO_SUCH_OBJECT``.
        """
        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self._ldap_filter(),
                attrlist=[],
            )
        except ldap.NO_SUCH_OBJECT:
            return 0

        # Clamp the raw total to the requested slice, as Django's own
        # Query.get_count() does.
        low_mark = self.low_mark or 0
        number = max(0, len(vals) - low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - low_mark)
        return number

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler bound to the module-level ldapdb connection."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
218         return self.get_compiler().results_iter()
219
220 class QuerySet(BaseQuerySet):
221     def __init__(self, model=None, query=None, using=None):
222         if not query:
223             import inspect
224             spec = inspect.getargspec(BaseQuery.__init__)
225             if len(spec[0]) == 3:
226                 # django 1.2
227                 query = Query(model, WhereNode)
228             else:
229                 # django 1.1
230                 query = Query(model, None, WhereNode)
231         super(QuerySet, self).__init__(model=model, query=query)
232
233     def delete(self):
234         "Bulk deletion."
235         try:
236             vals = ldapdb.connection.search_s(
237                 self.model.base_dn,
238                 ldap.SCOPE_SUBTREE,
239                 filterstr=self.query._ldap_filter(),
240                 attrlist=[],
241             )
242         except ldap.NO_SUCH_OBJECT:
243             return
244
245         # FIXME : there is probably a more efficient way to do this 
246         for dn, attrs in vals:
247             ldapdb.connection.delete_s(dn)
248