restrict selected fields
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    'gte' becomes '>=' and 'lte' becomes '<='; any other lookup type falls
    back to plain equality '='.
    """
    for name, op in (('gte', '>='), ('lte', '<=')):
        if lookup_type == name:
            return op
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    When the constraint has no model field, the lookup value is prepared
    with ldapdb's CharField.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        ``db_type`` is None when the constraint has no model field.

        Raises EmptyShortCircuit when preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            # EmptyShortCircuit was never imported into this module, so a bare
            # `raise EmptyShortCircuit` failed with NameError. Import it only on
            # this path.
            from django.db.models.sql.where import EmptyShortCircuit
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Runs an ldapdb Query against an LDAP connection and yields result rows.

    The where tree is turned into an LDAP filter by the query's
    _ldap_filter(); ordering and slicing are then applied client side.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def _field_value(self, field, dn, attrs):
        """
        Return the Python value of `field` for the entry (`dn`, `attrs`).

        The 'dn' field is taken from the entry's DN. Fields with a
        from_ldap() method decode their db_column attribute; a missing
        attribute is passed as []. Any other field yields None.
        """
        if field.attname == 'dn':
            return dn
        if hasattr(field, 'from_ldap'):
            return field.from_ldap(attrs.get(field.db_column, []),
                                   connection=self.connection)
        return None

    def _ordering(self):
        """
        Return the effective list of ordering field names.

        extra_order_by wins. Otherwise order_by is used, falling back to
        the model's Meta.ordering when default ordering is enabled.
        """
        query = self.query
        if query.extra_order_by:
            return query.extra_order_by
        if not query.default_ordering:
            return query.order_by
        return query.order_by or query.model._meta.ordering

    def _sort(self, entries, ordering):
        """
        Sort `entries` (a list of (dn, attrs) pairs) in place by `ordering`.

        Each name in `ordering` is a model field name, optionally prefixed
        with '-' for descending order. Values that have a lower() method
        compare case-insensitively.

        Python's sort is stable, including with reverse=True. Sorting once
        per field, from least to most significant, therefore gives the
        combined multi-field order. Each value is computed once per entry
        and field rather than once per comparison.
        """
        opts = self.query.model._meta
        for fieldname in reversed(list(ordering)):
            descending = fieldname.startswith('-')
            if descending:
                fieldname = fieldname[1:]
            field = opts.get_field(fieldname)

            def sort_key(entry, field=field):
                value = self._field_value(field, entry[0], entry[1])
                # perform case insensitive comparison
                if hasattr(value, 'lower'):
                    value = value.lower()
                return value

            entries.sort(key=sort_key, reverse=descending)

    def results_iter(self):
        """
        Yield one list of field values per matching LDAP entry.

        Only the db_column attributes of the selected fields are requested.
        The selected fields are query.select_fields, or all model fields
        when that is empty. Nothing is yielded when the search raises
        ldap.NO_SUCH_OBJECT.
        """
        query = self.query
        fields = query.select_fields or query.model._meta.fields
        attrlist = [field.db_column for field in fields if field.db_column]

        try:
            vals = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=query._ldap_filter(),
                attrlist=attrlist,
            )
        except ldap.NO_SUCH_OBJECT:
            # The base DN does not exist: no results.
            return

        vals = list(vals)
        ordering = self._ordering()
        # Fewer than two entries need no sorting.
        if ordering and len(vals) > 1:
            self._sort(vals, ordering)

        # FIXME : This is not optimal, we retrieve more results than we need
        # but there is probably no other options as we can't perform ordering
        # server side.
        for dn, attrs in vals[query.low_mark:query.high_mark]:
            yield [self._field_value(field, dn, attrs) for field in fields]
140
141
class WhereNode(BaseWhereNode):
    """
    Where-clause tree that renders itself as an LDAP search filter
    (prefix notation with '&', '|' and '!') instead of SQL.
    """
    def add(self, data, connector):
        """
        Add `data` to the tree, joined to its siblings with `connector`.

        For lookup tuples (obj, lookup_type, value) whose obj has a
        process() method, obj is replaced by this module's Constraint, so
        values are prepared with ldapdb's rules. Anything else is passed to
        the base class unchanged.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def _leaf_as_filter(self, item):
        """
        Render one stored leaf (constraint, lookup_type, _, values) as an
        LDAP filter clause. The third element of the leaf is not used.
        """
        constraint, lookup_type, _extra, values = item
        comp = get_lookup_operator(lookup_type)
        if hasattr(constraint, "col"):
            # django 1.2
            column = constraint.col
            if lookup_type == 'in':
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                return '(|%s)' % ''.join(equal_bits)
            return "(%s%s%s)" % (column, comp, values)

        # django 1.1: presumably the (alias, col, db_type) tuple returned by
        # Constraint.process(), with `values` a list of prepared params.
        (table, column, db_type) = constraint
        equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
        if len(equal_bits) == 1:
            return equal_bits[0]
        return '(|%s)' % ''.join(equal_bits)

    def as_sql(self, qn=None, connection=None):
        """
        Return ``(filter_string, [])`` for this node and its children.

        Child nodes are rendered recursively and leaves via
        _leaf_as_filter(). Several parts are combined with '&' or '|'
        depending on the connector. A negated node wraps the combined
        expression in '(!...)'.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                sql, _params = item.as_sql(qn=qn, connection=connection)
                bits.append(sql)
            else:
                bits.append(self._leaf_as_filter(item))

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # Negation applies to the node as a whole: NOT(a AND b). Negating
        # each leaf separately produced (NOT a) AND (NOT b). It also ignored
        # the negation for nested child nodes.
        if self.negated and bits:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
194
class Query(BaseQuery):
    """
    Django Query whose where clause is rendered as an LDAP filter and whose
    results are fetched through ldapdb's connection.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _ldap_filter(self):
        """
        Return the LDAP search filter for this query.

        The filter is the model's objectClass restrictions AND-ed with the
        rendered where clause.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def get_count(self, using=None):
        """
        Return the number of entries this query yields.

        Returns 0 when the search raises ldap.NO_SUCH_OBJECT. The query's
        low_mark/high_mark slice is applied to the count, as the compiler
        does for results.
        """
        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self._ldap_filter(),
                # '1.1' requests no attributes (RFC 4511). An empty list would
                # request every user attribute of every entry.
                attrlist=['1.1'],
            )
        except ldap.NO_SUCH_OBJECT:
            return 0

        total = len(vals)
        low = self.low_mark or 0
        if self.high_mark is None:
            high = total
        else:
            high = min(self.high_mark, total)
        return max(0, high - low)

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler bound to ldapdb's connection."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
224
class QuerySet(BaseQuerySet):
    """QuerySet for LDAP-backed models, built on this module's Query and WhereNode."""
    def __init__(self, model=None, query=None, using=None):
        if query is None:
            import inspect
            spec = inspect.getargspec(BaseQuery.__init__)
            # A 3-argument Query.__init__ is treated as Django 1.2. Anything
            # else is treated as Django 1.1, which takes an extra positional
            # argument (passed as None).
            if len(spec[0]) == 3:
                # django 1.2
                query = Query(model, WhereNode)
            else:
                # django 1.1
                query = Query(model, None, WhereNode)
        # NOTE(review): `using` is not forwarded, presumably because Django
        # 1.1's QuerySet.__init__ does not accept it -- confirm.
        super(QuerySet, self).__init__(model=model, query=query)

    def delete(self):
        "Bulk deletion."
        # The LDAP search below ignores low_mark/high_mark. Deleting a sliced
        # queryset would therefore delete every matching entry, so refuse it
        # the way Django's own QuerySet.delete() does.
        if not self.query.can_filter():
            raise AssertionError("Cannot use 'limit' or 'offset' with delete.")

        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self.query._ldap_filter(),
                # Only DNs are needed: '1.1' requests no attributes (RFC 4511).
                attrlist=['1.1'],
            )
        except ldap.NO_SUCH_OBJECT:
            return

        # FIXME : there is probably a more efficient way to do this
        for dn, attrs in vals:
            ldapdb.connection.delete_s(dn)

        # Any cached results no longer reflect the directory.
        self._result_cache = None
253