catch ldap.NO_SUCH_OBJECT for calls to ldap.search_s
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
# Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Return the LDAP filter comparison operator for a Django lookup type.

    'gte' gives '>=', 'lte' gives '<='; any other lookup type
    (e.g. 'exact', 'in') gives '='.
    """
    for name, operator in (('gte', '>='), ('lte', '<=')):
        if lookup_type == name:
            return operator
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    Lookup values are prepared with the field's get_db_prep_lookup(); when
    the constraint has no field, an ldapdb CharField is used instead.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        ``db_type`` is None when the constraint has no field.

        Raises EmptyShortCircuit if preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import these here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit was previously referenced without ever being
        # imported, turning the ObjectDoesNotExist path into a NameError.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Minimal query compiler that runs a Query against the LDAP directory.

    Only results_iter() is provided: it performs a subtree search below the
    model's ``base_dn`` and yields one row per matching entry.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def _field_value(self, field, entry):
        """
        Return the Python value of ``field`` for one search result.

        ``entry`` is a ``(dn, attrs)`` pair as returned by ``search_s``.
        The ``dn`` field takes the entry's DN, fields with a ``from_ldap``
        method decode their LDAP attribute, and any other field yields None.
        """
        dn, attrs = entry
        if field.attname == 'dn':
            return dn
        if hasattr(field, 'from_ldap'):
            return field.from_ldap(attrs.get(field.db_column, []), connection=self.connection)
        return None

    def _ordering(self):
        """Return the list of (optionally '-'-prefixed) field names to sort by."""
        query = self.query
        if query.extra_order_by:
            return query.extra_order_by
        if not query.default_ordering:
            return query.order_by
        return query.order_by or query.model._meta.ordering

    def results_iter(self):
        """
        Yield one row (a list of values in ``model._meta.fields`` order) for
        each entry matching the query.

        Results are sorted client-side according to the query's ordering,
        compared case-insensitively where values have ``lower()``. The
        query's offset/limit (``low_mark``/``high_mark``, set by queryset
        slicing and indexing) are then applied. A missing base DN yields no
        rows.
        """
        query = self.query
        attrlist = [x.db_column for x in query.model._meta.local_fields if x.db_column]

        try:
            vals = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=query._ldap_filter(),
                attrlist=attrlist,
            )
        except ldap.NO_SUCH_OBJECT:
            # The base DN does not exist: behave like an empty result set.
            return

        ordering = self._ordering()

        def cmpvals(x, y):
            for fieldname in ordering:
                if fieldname.startswith('-'):
                    fieldname = fieldname[1:]
                    negate = True
                else:
                    negate = False
                field = query.model._meta.get_field(fieldname)
                # Use the same value extraction as the row builder below so
                # that ordering by 'dn' compares the actual DNs.
                attr_x = self._field_value(field, x)
                attr_y = self._field_value(field, y)
                # perform case insensitive comparison
                if hasattr(attr_x, 'lower'):
                    attr_x = attr_x.lower()
                if hasattr(attr_y, 'lower'):
                    attr_y = attr_y.lower()
                if negate:
                    val = cmp(attr_y, attr_x)
                else:
                    val = cmp(attr_x, attr_y)
                if val:
                    return val
            return 0

        # Sorting with no ordering would be a (stable) no-op, so skip it.
        if ordering:
            vals = sorted(vals, cmp=cmpvals)

        # Honour queryset slicing/indexing; high_mark is an absolute index
        # (or None for "no limit").
        vals = vals[query.low_mark:query.high_mark]

        # process results
        fields = query.model._meta.fields
        for entry in vals:
            yield [self._field_value(field, entry) for field in fields]
126
127
class WhereNode(BaseWhereNode):
    """
    WHERE tree whose as_sql() renders an LDAP search filter instead of SQL.
    """
    def add(self, data, connector):
        """
        Add a child to this node.

        Lookup tuples ``(obj, lookup_type, value)`` whose ``obj`` has a
        ``process`` method get that object replaced by an ldapdb Constraint
        built from its alias, col and field. Anything else is passed to
        the base class unchanged.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    @staticmethod
    def _leaf_clause(item):
        """
        Render one leaf child ``(constraint, lookup_type, annotation, values)``
        as an LDAP filter component.

        Multiple values (an 'in' lookup, or several prepared params) are
        OR-ed together. Values are interpolated as-is, without escaping here.
        """
        constraint, lookup_type, annotation, values = item
        comp = get_lookup_operator(lookup_type)
        if hasattr(constraint, "col"):
            # django 1.2
            column = constraint.col
            if lookup_type == 'in':
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                return '(|%s)' % ''.join(equal_bits)
            return "(%s%s%s)" % (column, comp, values)

        # django 1.1
        (table, column, db_type) = constraint
        equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
        if len(equal_bits) == 1:
            return equal_bits[0]
        return '(|%s)' % ''.join(equal_bits)

    def as_sql(self, qn=None, connection=None):
        """
        Render this node as an LDAP filter.

        Children are combined with '&' (AND) or '|' (OR) according to the
        node's connector; a single child is used as-is. If the node is
        negated, the combined expression as a whole is wrapped in '(!...)'.

        Returns a ``(filter, params)`` pair; params is always an empty list.
        Raises Exception for an unknown connector.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                sql, params = item.as_sql(qn=qn, connection=connection)
                bits.append(sql)
            else:
                bits.append(self._leaf_clause(item))

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # Negation applies to the node as a whole: NOT (a AND b) must not be
        # rendered as (!a) AND (!b), and it must also cover child nodes.
        # An empty node is left un-negated, as before.
        if self.negated and bits:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
180
class Query(BaseQuery):
    """
    Django Query subclass whose WHERE tree is rendered as an LDAP filter
    and executed through the ldapdb connection.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _ldap_filter(self):
        """
        Return the full LDAP search filter for this query.

        The filter is every objectClass of the model AND-ed with the
        rendered WHERE tree.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def get_count(self, using=None):
        """
        Return the number of entries matching this query.

        The query's offset/limit (``low_mark``/``high_mark``, set by
        queryset slicing) are applied to the raw match count, as Django's
        own Query.get_count() does. A missing base DN counts as zero.
        """
        # NOTE(review): uses ldapdb.connection rather than self.connection,
        # presumably because Query.clone() does not copy custom attributes.
        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self._ldap_filter(),
                attrlist=[],
            )
        except ldap.NO_SUCH_OBJECT:
            return 0

        number = max(0, len(vals) - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)
        return number

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler bound to the ldapdb connection (``connection`` is ignored)."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
210
class QuerySet(BaseQuerySet):
    """
    QuerySet for LDAP-backed models, built on the ldapdb Query/WhereNode.
    """
    def __init__(self, model=None, query=None, using=None):
        if not query:
            import inspect
            # Pick the Query constructor signature by the arity of Django's
            # BaseQuery.__init__ (3 args on django 1.2, otherwise 1.1).
            spec = inspect.getargspec(BaseQuery.__init__)
            if len(spec[0]) == 3:
                # django 1.2
                query = Query(model, WhereNode)
            else:
                # django 1.1
                query = Query(model, None, WhereNode)
        super(QuerySet, self).__init__(model=model, query=query)

    def delete(self):
        """
        Bulk deletion: delete every entry matching this queryset.

        Refuses sliced querysets, as Django's QuerySet.delete() does, since
        the search would otherwise delete every match regardless of the
        slice. Does nothing if the base DN does not exist.
        """
        if not self.query.can_filter():
            raise AssertionError("Cannot use 'limit' or 'offset' with delete.")

        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=self.query._ldap_filter(),
                attrlist=[],
            )
        except ldap.NO_SUCH_OBJECT:
            return

        # Clear the result cache, in case this QuerySet gets reused; done
        # before deleting so a failure part-way does not leave stale rows.
        self._result_cache = None

        # FIXME : there is probably a more efficient way to do this
        for dn, attrs in vals:
            ldapdb.connection.delete_s(dn)
239