accept "using" keyword for get_count
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Map a Django lookup type to the comparison operator used in the LDAP
    filter.

    'gte' maps to '>=' and 'lte' maps to '<='; every other lookup type
    (for example 'exact' or 'in') falls back to equality, '='.
    """
    for name, symbol in (('gte', '>='), ('lte', '<=')):
        if lookup_type == name:
            return symbol
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        The value is prepared with the constraint's own field. When the
        constraint has no field, an ldapdb CharField prepares it and
        db_type is None.

        Raises EmptyShortCircuit if preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not imported at module level. Without this
        # import, the except branch below would fail with NameError instead
        # of raising the intended exception.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Runs a Query as an LDAP subtree search and yields the results as rows
    of field values, in place of Django's SQL compiler.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        # database alias passed in by the caller; stored but not used here
        self.using = using

    def _get_ordering(self):
        """Return the field names (optionally '-'-prefixed) to sort results by."""
        query = self.query
        if query.extra_order_by:
            return query.extra_order_by
        elif not query.default_ordering:
            return query.order_by
        return query.order_by or query.model._meta.ordering

    def _sort_entries(self, vals, ordering):
        """
        Return the (dn, attrs) search results in `vals` sorted by `ordering`.

        Each attribute is decoded with its field's from_ldap(). Values that
        have a lower() method are compared case-insensitively. A leading
        '-' on a field name sorts that field in descending order.

        Python's sort is stable, including with reverse=True. Sorting once
        per field, from the least to the most significant, therefore gives
        the same order as one lexicographic comparison. It also decodes each
        attribute once per entry instead of once per comparison.
        """
        vals = list(vals)
        if len(vals) < 2:
            # nothing to reorder; also skips resolving the ordering fields
            return vals

        meta = self.query.model._meta
        connection = self.connection

        def sort_value(entry, field):
            value = field.from_ldap(entry[1].get(field.db_column, []), connection=connection)
            # perform case insensitive comparison
            if hasattr(value, 'lower'):
                value = value.lower()
            return value

        for fieldname in reversed(list(ordering)):
            descending = fieldname.startswith('-')
            if descending:
                fieldname = fieldname[1:]
            field = meta.get_field(fieldname)
            vals.sort(key=lambda entry, field=field: sort_value(entry, field),
                      reverse=descending)
        return vals

    def results_iter(self):
        """
        Search under the model's base_dn (subtree scope) and yield one list
        of values per entry, in the order of the model's fields.

        The 'dn' field receives the entry's DN. Fields that have from_ldap()
        are decoded from the entry's attributes. Any other field gets None.
        Raises the model's DoesNotExist if the LDAP search fails.
        """
        query = self.query

        classes = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
        where, params = query.where.as_sql()
        filterstr = '(&%s%s)' % (classes, where)
        attrlist = [x.db_column for x in query.model._meta.local_fields if x.db_column]

        try:
            vals = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=attrlist,
            )
        except ldap.LDAPError:
            # Only LDAP failures are reported as DoesNotExist. Programming
            # errors and KeyboardInterrupt propagate unchanged.
            raise query.model.DoesNotExist

        for dn, attrs in self._sort_entries(vals, self._get_ordering()):
            row = []
            for field in query.model._meta.fields:
                if field.attname == 'dn':
                    row.append(dn)
                elif hasattr(field, 'from_ldap'):
                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
                else:
                    row.append(None)
            yield row
130
131
class WhereNode(BaseWhereNode):
    """
    WHERE-clause tree that renders itself as an LDAP search filter in
    prefix notation, e.g. ``(&(a=1)(b=2))``, instead of SQL.
    """
    def add(self, data, connector):
        """
        Add a child to this node. Native constraints are replaced with this
        module's Constraint so that its process() prepares the values.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self, qn=None, connection=None):
        """
        Render this node as an LDAP filter string.

        Returns ``(filter, [])``. Parameters are always empty because values
        are inlined into the filter. An empty node renders as ``''``, so
        callers can append it without adding a term.

        Negation applies to the node as a whole: a negated AND of a and b
        becomes ``(!(&(a)(b)))``. It is not distributed over the leaves.
        Negated nodes whose children are sub-nodes are therefore handled too.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                # nested node: render recursively and drop empty subtrees
                sql, params = item.as_sql(qn=qn, connection=connection)
                if sql:
                    bits.append(sql)
                continue

            constraint, lookup_type, y, values = item
            comp = get_lookup_operator(lookup_type)
            if hasattr(constraint, "col"):
                # django 1.2
                column = constraint.col
                if lookup_type == 'in':
                    equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                    clause = '(|%s)' % ''.join(equal_bits)
                else:
                    clause = "(%s%s%s)" % (column, comp, values)
            else:
                # django 1.1
                (table, column, db_type) = constraint
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                if len(equal_bits) == 1:
                    clause = equal_bits[0]
                else:
                    clause = '(|%s)' % ''.join(equal_bits)
            bits.append(clause)

        if not bits:
            return '', []

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
184
class Query(BaseQuery):
    """
    Django query that computes its count with an LDAP search and hands
    result fetching to this module's Compiler.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def get_count(self, using=None):
        """
        Return the number of directory entries matching this query.

        `using` is accepted so callers can pass a database alias; it is
        otherwise ignored. Raises the model's DoesNotExist if the LDAP
        search fails.
        """
        classes = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        where, params = self.where.as_sql()
        filterstr = '(&%s%s)' % (classes, where)

        # Use the module-level connection, as get_compiler() does, rather
        # than self.connection. NOTE: Django 1.2's Query.clone() builds
        # copies without running __init__, so a cloned query (e.g. from
        # .filter()) may not have a `connection` attribute.
        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=[],
            )
        except ldap.LDAPError:
            # Only LDAP failures are reported as DoesNotExist; other errors
            # (e.g. programming mistakes) propagate unchanged.
            raise self.model.DoesNotExist

        return len(vals)

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler bound to the module-level LDAP connection."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
214
class QuerySet(BaseQuerySet):
    """
    QuerySet for LDAP-backed models, built on this module's Query and
    WhereNode classes.
    """
    def __init__(self, model=None, query=None, using=None):
        import inspect

        if not query:
            spec = inspect.getargspec(BaseQuery.__init__)
            if len(spec[0]) == 3:
                # django 1.2: Query(model, where_class)
                query = Query(model, WhereNode)
            else:
                # django 1.1: Query(model, connection, where_class)
                query = Query(model, None, WhereNode)

        # Forward the database alias when the base QuerySet accepts it
        # (Django 1.2 multi-db). Django 1.1's constructor has no `using`
        # argument, so it is omitted there.
        if 'using' in inspect.getargspec(BaseQuerySet.__init__)[0]:
            super(QuerySet, self).__init__(model=model, query=query, using=using)
        else:
            super(QuerySet, self).__init__(model=model, query=query)
227