d0958b18c4386bb38c019327c293c7d84f8c29d4
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    'gte' maps to '>=' and 'lte' maps to '<='. Every other lookup type
    (including 'gt' and 'lt', which have no strict LDAP counterpart here)
    falls back to the equality operator '='.
    """
    operators = {
        'gte': '>=',
        'lte': '<=',
    }
    return operators.get(lookup_type, '=')
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    Values are prepared with the attached field's ``get_db_prep_lookup``,
    or with an ldapdb ``CharField`` when no field is attached.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        ``db_type`` is None when no field is attached.

        Raises EmptyShortCircuit when preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not imported at module level in this file;
        # without this import the except branch below raised NameError.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class Compiler(object):
    """
    Executes a ldapdb Query as an LDAP subtree search and turns the
    returned entries into result rows.
    """
    def __init__(self, query, connection, using):
        self.query = query
        self.connection = connection
        self.using = using

    def _sort_value(self, field, attrs):
        """
        Decode ``field`` from an entry's attribute dict for client-side sorting.

        Values that have a ``lower()`` method are lower-cased so text ordering
        is case-insensitive; other decoded values are compared as-is.
        """
        value = field.from_ldap(attrs.get(field.db_column, []), connection=self.connection)
        # Only text can be case-folded; calling .lower() unconditionally
        # raised AttributeError for fields whose from_ldap returns non-text.
        if hasattr(value, 'lower'):
            value = value.lower()
        return value

    def results_iter(self):
        """
        Run the LDAP search for this query and yield one row per entry.

        The filter ANDs one ``(objectClass=...)`` term per entry of
        ``model.object_classes`` with the rendered WHERE tree, and the search
        covers the subtree under ``model.base_dn``. Entries are sorted
        client-side by the effective ordering. Each yielded row is a list
        aligned with ``model._meta.fields``:
        - the entry DN for the field whose attname is 'dn';
        - the decoded attribute for fields that define ``from_ldap``;
        - None otherwise.

        Raises ``query.model.DoesNotExist`` if the LDAP search fails.
        """
        query = self.query

        filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
        sql, params = query.where.as_sql()
        filterstr += sql
        filterstr = '(&%s)' % filterstr
        attrlist = [x.db_column for x in query.model._meta.local_fields if x.db_column]

        try:
            vals = self.connection.search_s(
                query.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=attrlist,
            )
        except ldap.LDAPError:
            # Only LDAP failures map to DoesNotExist; a bare except also
            # swallowed programming errors and KeyboardInterrupt.
            raise query.model.DoesNotExist

        # perform sorting
        if query.extra_order_by:
            ordering = query.extra_order_by
        elif not query.default_ordering:
            ordering = query.order_by
        else:
            ordering = query.order_by or query.model._meta.ordering

        # With fewer than two entries there is nothing to compare, so skip
        # resolving ordering fields entirely (same as before, where fields
        # were only looked up during comparisons).
        if len(vals) > 1:
            # Resolve each ordering entry to (field, descending) once instead
            # of on every comparison.
            sort_fields = []
            for fieldname in ordering:
                descending = fieldname.startswith('-')
                if descending:
                    fieldname = fieldname[1:]
                sort_fields.append((query.model._meta.get_field(fieldname), descending))

            def cmpvals(x, y):
                for field, descending in sort_fields:
                    attr_x = self._sort_value(field, x[1])
                    attr_y = self._sort_value(field, y[1])
                    val = cmp(attr_y, attr_x) if descending else cmp(attr_x, attr_y)
                    if val:
                        return val
                return 0
            vals = sorted(vals, cmp=cmpvals)

        # process results
        for dn, attrs in vals:
            row = []
            for field in query.model._meta.fields:
                if field.attname == 'dn':
                    row.append(dn)
                elif hasattr(field, 'from_ldap'):
                    row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
                else:
                    row.append(None)
            yield row
125
126
class WhereNode(BaseWhereNode):
    """
    WHERE tree that renders itself as an LDAP search filter instead of SQL.

    Leaf children are ``(constraint, lookup_type, annotation, values)``
    tuples; other children are nodes providing their own ``as_sql``.
    """
    def add(self, data, connector):
        """
        Add a child to the tree. For constraint tuples, Django's native
        Constraint is swapped for the ldapdb Constraint so values are
        prepared by ldapdb fields.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def _leaf_filter(self, constraint, lookup_type, values):
        """Render one leaf child as a (non-negated) LDAP filter component."""
        # NOTE(review): values are interpolated verbatim; escaping of LDAP
        # filter metacharacters is assumed to happen during field lookup
        # preparation — confirm for both Django code paths.
        comp = get_lookup_operator(lookup_type)
        if hasattr(constraint, "col"):
            # django 1.2: constraint is still a Constraint object
            column = constraint.col
            if lookup_type == 'in':
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                return '(|%s)' % ''.join(equal_bits)
            return "(%s%s%s)" % (column, comp, values)

        # django 1.1: constraint is the (alias, column, db_type) tuple
        # produced by Constraint.process, and values is a list of params.
        (table, column, db_type) = constraint
        equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
        if len(equal_bits) == 1:
            return equal_bits[0]
        return '(|%s)' % ''.join(equal_bits)

    def as_sql(self, qn=None, connection=None):
        """
        Render this node as an LDAP filter string.

        Returns ``(filter, [])``. The filter is '' when no child produces
        output. Child filters are combined with ``(&...)`` or ``(|...)``
        according to the connector. The combined filter is wrapped in
        ``(!...)`` when this node is negated.

        Raises Exception for an unknown connector.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                sql, params = item.as_sql(qn=qn, connection=connection)
                # Skip empty subtrees rather than emitting an empty component.
                if sql:
                    bits.append(sql)
                continue

            constraint, lookup_type, annotation, values = item
            bits.append(self._leaf_filter(constraint, lookup_type, values))

        if not bits:
            # No constraints: emit nothing instead of '(&)' / '(|)', the
            # RFC 4526 absolute true/false filters that not every server
            # supports.
            return '', []

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # A negated node means NOT(child1 <connector> child2 ...), so negate
        # the combined filter once. Negating each leaf separately was wrong
        # for multi-child nodes (De Morgan) and never negated nested subtrees.
        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
179
class Query(BaseQuery):
    """
    Django Query subclass bound to ``ldapdb.connection``. It provides an
    LDAP-backed ``get_count`` and returns a Compiler from ``get_compiler``.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def get_count(self):
        """
        Return the number of entries under ``model.base_dn`` (subtree scope)
        that match the model's object classes and this query's WHERE tree.

        Raises ``model.DoesNotExist`` if the LDAP search fails.
        """
        filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        sql, params = self.where.as_sql()
        filterstr += sql
        filterstr = '(&%s)' % filterstr

        try:
            vals = self.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                # Only the number of entries is needed. '1.1' is the RFC 4511
                # "no attributes" selector; an empty list would request all
                # user attributes of every entry.
                attrlist=['1.1'],
            )
        except ldap.LDAPError:
            # Only LDAP failures map to DoesNotExist; a bare except also
            # swallowed programming errors and KeyboardInterrupt.
            raise self.model.DoesNotExist

        return len(vals)

    def get_compiler(self, using=None, connection=None):
        """Return a Compiler using ``ldapdb.connection``; ``connection`` is ignored."""
        return Compiler(self, ldapdb.connection, using)

    def results_iter(self):
        "For django 1.1 compatibility"
        return self.get_compiler().results_iter()
209
class QuerySet(BaseQuerySet):
    """
    QuerySet whose default query is an ldapdb Query built with the
    LDAP-rendering WhereNode.
    """
    def __init__(self, model=None, query=None, using=None):
        import inspect
        if query is None:
            spec = inspect.getargspec(BaseQuery.__init__)
            if len(spec[0]) == 3:
                # django 1.2: Query(model, where)
                query = Query(model, WhereNode)
            else:
                # django 1.1: Query(model, connection, where)
                query = Query(model, None, WhereNode)

        if 'using' in inspect.getargspec(BaseQuerySet.__init__)[0]:
            # The base QuerySet accepts ``using`` (django 1.2); forward it
            # instead of silently discarding the caller's database alias.
            super(QuerySet, self).__init__(model=model, query=query, using=using)
        else:
            # django 1.1 QuerySet has no ``using`` parameter.
            super(QuerySet, self).__init__(model=model, query=query)
222