implement and test bulk deletion
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    Only 'gte' and 'lte' have dedicated operators ('>=' and '<=');
    any other lookup type falls back to equality ('=').
    """
    special_operators = {
        'gte': '>=',
        'lte': '<=',
    }
    return special_operators.get(lookup_type, '=')
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    Lookup values are prepared with the constraint's own field's
    get_db_prep_lookup(); when the constraint has no field, an LDAP
    CharField is used instead.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        db_type is None when the constraint has no field.

        Raises EmptyShortCircuit if preparing the value raises
        ObjectDoesNotExist (presumably so Django can replace the lookup with
        a "matches nothing" node -- confirm against the Django version in use).
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # Not provided by the module-level imports; without this the except
        # branch below would fail with a NameError instead of short-circuiting.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Runs a Query against the LDAP directory.

    Instances are created by Query.get_compiler(); results_iter() performs
    the search and yields rows.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def results_iter(self):
        """
        Yield one row per LDAP entry matching the query.

        Each row is a list with one value per field in
        ``query.model._meta.fields``:
        - the entry's DN for the field whose attname is 'dn';
        - ``field.from_ldap(...)`` for fields that define from_ldap;
        - None otherwise.

        Ordering and slicing (``query.low_mark`` / ``query.high_mark``) are
        applied here, client side, after the search returns, because neither
        is passed to the LDAP server.
        """
        query = self.query
        # only request attributes that are mapped to a model field
        attrlist = [x.db_column for x in query.model._meta.local_fields if x.db_column]

        vals = self.connection.search_s(
            query.model.base_dn,
            ldap.SCOPE_SUBTREE,
            filterstr=query._ldap_filter(),
            attrlist=attrlist,
        )

        # determine the ordering to apply
        if query.extra_order_by:
            ordering = query.extra_order_by
        elif not query.default_ordering:
            ordering = query.order_by
        else:
            ordering = query.order_by or query.model._meta.ordering

        def cmpvals(x, y):
            "Compare two (dn, attrs) search results according to `ordering`."
            for fieldname in ordering:
                # a leading '-' requests descending order for this field
                if fieldname.startswith('-'):
                    fieldname = fieldname[1:]
                    negate = True
                else:
                    negate = False
                field = query.model._meta.get_field(fieldname)
                attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
                attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
                # perform case insensitive comparison
                if hasattr(attr_x, 'lower'):
                    attr_x = attr_x.lower()
                if hasattr(attr_y, 'lower'):
                    attr_y = attr_y.lower()
                # descending order: compare with the operands swapped
                if negate:
                    attr_x, attr_y = attr_y, attr_x
                val = cmp(attr_x, attr_y)
                if val:
                    return val
            return 0

        # perform sorting (only when an ordering is actually requested)
        if ordering:
            vals = sorted(vals, cmp=cmpvals)

        # Django's QuerySet slicing/indexing only records the limits on the
        # query (low_mark, high_mark); the backend must honour them, otherwise
        # qs[a:b] returns everything and qs[k] always returns the first row.
        vals = vals[query.low_mark:query.high_mark]

        # process results
        for dn, attrs in vals:
            row = []
            for field in query.model._meta.fields:
                if field.attname == 'dn':
                    row.append(dn)
                elif hasattr(field, 'from_ldap'):
                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
                else:
                    row.append(None)
            yield row
122
123
class WhereNode(BaseWhereNode):
    """
    WHERE-clause tree that renders itself as an LDAP search filter
    (prefix notation: ``(&...)``, ``(|...)``, ``(!...)``) instead of SQL.
    """
    def add(self, data, connector):
        """
        Add a child to this node.

        For lookup tuples ``(constraint, lookup_type, value)`` whose
        constraint has a ``process`` method, the constraint is replaced by
        this module's Constraint (same alias, col and field) before being
        handed to the base implementation. Anything else is passed through
        unchanged.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self, qn=None, connection=None):
        """
        Render this node as an LDAP filter.

        Returns ``(filter_string, [])``; the params list is always empty and
        only exists to mirror Django's as_sql() API.

        An empty node renders as '' so that callers can simply omit it. When
        the node is negated, the negation wraps the whole combined
        expression (NOT (a AND b)), matching how Django's SQL WhereNode
        applies negation.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                # child that renders itself (e.g. a nested WhereNode)
                sql, params = item.as_sql(qn=qn, connection=connection)
                if sql:
                    bits.append(sql)
                continue

            constraint, lookup_type, y, values = item
            comp = get_lookup_operator(lookup_type)
            if hasattr(constraint, "col"):
                # django 1.2
                # NOTE(review): values are interpolated here without going
                # through Constraint.process(), so no LDAP filter escaping
                # happens in this branch -- confirm they are escaped upstream.
                column = constraint.col
                if lookup_type == 'in':
                    equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                    clause = '(|%s)' % ''.join(equal_bits)
                else:
                    clause = "(%s%s%s)" % (column, comp, values)
            else:
                # django 1.1: constraint is the (alias, col, db_type) tuple and
                # values the params list, as produced by Constraint.process()
                (table, column, db_type) = constraint
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                if len(equal_bits) == 1:
                    clause = equal_bits[0]
                else:
                    clause = '(|%s)' % ''.join(equal_bits)

            bits.append(clause)

        if not bits:
            return '', []

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # negate the combined expression, not each child individually
        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
176
class Query(BaseQuery):
    """
    Django Query subclass that targets the LDAP directory via
    ldapdb.connection instead of an SQL database.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _ldap_filter(self):
        """
        Build the full LDAP search filter for this query.

        The result ANDs one ``(objectClass=...)`` term per entry in
        ``self.model.object_classes`` with the filter rendered by the
        where-clause tree.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def get_count(self, using=None):
        """
        Return the number of entries matched by this query.

        Runs the search requesting no attributes and counts the results. The
        query's slice limits (low_mark / high_mark) are then applied the same
        way Django's own Query.get_count() does, so a sliced query only counts
        the rows inside the slice.
        """
        vals = ldapdb.connection.search_s(
            self.model.base_dn,
            ldap.SCOPE_SUBTREE,
            filterstr=self._ldap_filter(),
            attrlist=[],
        )
        number = len(vals)
        # apply the limits recorded by QuerySet slicing
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)
        return number

    def get_compiler(self, using=None, connection=None):
        "Return a Compiler bound to ldapdb.connection; both arguments are ignored."
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
202         return self.get_compiler().results_iter()
203
class QuerySet(BaseQuerySet):
    """
    QuerySet for LDAP-backed models, built on this module's Query and
    WhereNode classes.
    """
    def __init__(self, model=None, query=None, using=None):
        if not query:
            import inspect
            # Pick the Query constructor form by the arity of Django's
            # BaseQuery.__init__ (3 positional args including self -> 1.2).
            spec = inspect.getargspec(BaseQuery.__init__)
            if len(spec[0]) == 3:
                # django 1.2
                query = Query(model, WhereNode)
            else:
                # django 1.1
                query = Query(model, None, WhereNode)
        # NOTE(review): `using` is accepted but not forwarded, presumably
        # because the django 1.1 base constructor does not take it -- confirm
        # before passing it through.
        super(QuerySet, self).__init__(model=model, query=query)

    def delete(self):
        """
        Bulk deletion: delete every LDAP entry matched by this QuerySet.

        Raises AssertionError on a sliced QuerySet (same message as Django's
        QuerySet.delete()), because the search below ignores the slice and
        would delete every matching entry.
        """
        if not self.query.can_filter():
            raise AssertionError("Cannot use 'limit' or 'offset' with delete.")

        vals = ldapdb.connection.search_s(
            self.model.base_dn,
            ldap.SCOPE_SUBTREE,
            filterstr=self.query._ldap_filter(),
            attrlist=[],
        )
        # A child entry's DN is its RDN followed by its parent's DN, so
        # deleting the longest DNs first removes children before their
        # parents (assumes the server returns DNs in a consistent string form).
        dns = sorted([dn for dn, attrs in vals], key=len, reverse=True)
        # FIXME : there is probably a more efficient way to do this
        for dn in dns:
            ldapdb.connection.delete_s(dn)

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
    delete.alters_data = True
228