make it possible to use search in admin interface
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Returns the LDAP filter comparison operator for a Django lookup type.

    'gte' maps to '>=' and 'lte' maps to '<='; every other lookup type
    falls back to the equality operator '='.
    """
    special_operators = (
        ('gte', '>='),
        ('lte', '<='),
    )
    for name, operator in special_operators:
        if lookup_type == name:
            return operator
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance.

        The result is ``((alias, col, db_type), params)`` where ``params``
        comes from the field's ``get_db_prep_lookup()``.  When the
        constraint has no field, a plain ldapdb ``CharField`` prepares the
        value and ``db_type`` is None.

        Raises EmptyShortCircuit if preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not part of the module-level imports; without
        # this import the except branch below would fail with a NameError.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class WhereNode(BaseWhereNode):
    """
    A WHERE tree whose as_sql() renders an LDAP search filter string
    instead of SQL.
    """
    def add(self, data, connector):
        """
        Adds a child to the tree.

        Anything that is not a list or tuple is handed to the base class
        untouched.  For an (obj, lookup_type, value) triple whose obj has a
        ``process`` attribute, obj is first replaced by an ldapdb
        Constraint built from its alias, col and field.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self, qn=None):
        """
        Returns a ``(filter_string, params)`` tuple for this node.

        ``params`` is always an empty list because values are interpolated
        straight into the filter string.  ``qn`` is accepted for interface
        compatibility and is not used.  A node with nothing to render
        yields an empty filter string.

        Raises Exception if several clauses must be combined and the
        connector is neither AND nor OR.
        """
        bits = []
        for item in self.children:
            if isinstance(item, WhereNode):
                sql, params = item.as_sql()
                # an empty sub-tree contributes nothing to the filter
                if sql:
                    bits.append(sql)
                continue

            constraint, lookup_type, _unused, values = item
            comp = get_lookup_operator(lookup_type)
            if hasattr(constraint, "col"):
                # django 1.2
                column = constraint.col
                if lookup_type == 'in':
                    equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                    clause = '(|%s)' % ''.join(equal_bits)
                else:
                    clause = "(%s%s%s)" % (column, comp, values)
            else:
                # django 1.1
                (table, column, db_type) = constraint
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                if len(equal_bits) == 1:
                    clause = equal_bits[0]
                else:
                    clause = '(|%s)' % ''.join(equal_bits)
            bits.append(clause)

        # Nothing to filter on: return an empty string rather than a bare
        # '(&)' / '(|)' so that callers can append the result safely.
        if not bits:
            return '', []

        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        # Negation applies to the node as a whole.  Negating every leaf
        # separately would turn NOT (a AND b) into (NOT a) AND (NOT b), and
        # would lose the negation entirely for nested nodes.
        if self.negated:
            sql_string = '(!%s)' % sql_string

        return sql_string, []
118
class Query(BaseQuery):
    """
    A query that is executed as an LDAP subtree search on
    ``ldapdb.connection`` rather than as SQL.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _get_filterstr(self):
        """
        Returns the LDAP filter: the model's first object class ANDed with
        whatever the WHERE tree renders.
        """
        # FIXME: use all object classes
        filterstr = '(objectClass=%s)' % self.model.object_classes[0]
        sql, params = self.where.as_sql()
        filterstr += sql
        return '(&%s)' % filterstr

    def _search(self, attrlist):
        """
        Runs a subtree search below the model's base DN and returns the
        result of ``search_s``.

        Raises the model's DoesNotExist if the search fails with an LDAP
        error.
        """
        # The filter is built outside the try block so that errors raised
        # while rendering the WHERE tree are not masked as DoesNotExist.
        filterstr = self._get_filterstr()
        try:
            return ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=attrlist,
            )
        except ldap.LDAPError:
            # Only LDAP failures are translated; a bare except would also
            # swallow KeyboardInterrupt and programming errors.
            raise self.model.DoesNotExist

    def _sort_results(self, vals, ordering):
        """
        Returns the (dn, attrs) entries of ``vals`` as a list sorted by the
        field names in ``ordering``.

        A leading '-' on a field name means descending order.  Comparison
        is case insensitive and a missing attribute sorts as ''.
        """
        def fold(value):
            # NOTE(review): python-ldap normally returns attribute values as
            # lists of strings; both plain strings and sequences are
            # handled here - confirm against ldapdb.connection.
            if isinstance(value, (list, tuple)):
                return [fold(item) for item in value]
            if hasattr(value, 'lower'):
                return value.lower()
            return value

        vals = list(vals)
        if len(vals) < 2:
            # nothing to compare, so ordering fields are not even resolved
            return vals

        # list.sort() is stable (also with reverse=True), so sorting on the
        # least significant field first gives a multi-field ordering.
        for field in reversed(list(ordering)):
            descending = field.startswith('-')
            if descending:
                field = field[1:]
            attr = self.model._meta.get_field(field).db_column
            vals.sort(key=lambda entry, attr=attr: fold(entry[1].get(attr, '')),
                      reverse=descending)
        return vals

    def get_count(self):
        """
        Returns the number of entries matching the query.

        Raises the model's DoesNotExist if the search fails with an LDAP
        error.
        """
        return len(self._search([]))

    def results_iter(self):
        """
        Yields one row (a list of values) per matching entry, in the
        requested order.

        Each row follows the order of the model's fields: the field whose
        attname is 'dn' receives the entry's DN, every other field receives
        the raw value of its db_column attribute, or None if the entry does
        not have it.

        Raises the model's DoesNotExist if the search fails with an LDAP
        error.
        """
        attrlist = [x.db_column for x in self.model._meta.local_fields if x.db_column]
        vals = self._search(attrlist)

        # perform sorting
        if self.extra_order_by:
            ordering = self.extra_order_by
        elif not self.default_ordering:
            ordering = self.order_by
        else:
            ordering = self.order_by or self.model._meta.ordering
        vals = self._sort_results(vals, ordering)

        # process results
        fields = self.model._meta.fields
        for dn, attrs in vals:
            row = []
            for field in fields:
                if field.attname == 'dn':
                    row.append(dn)
                else:
                    row.append(attrs.get(field.db_column, None))
            yield row
193
class QuerySet(BaseQuerySet):
    """
    A QuerySet that defaults to an ldapdb Query using the ldapdb WhereNode.
    """
    def __init__(self, model=None, query=None, using=None):
        """
        Builds a default Query for ``model`` when no query is supplied.

        ``using`` is accepted but is not forwarded to the base class.
        """
        if not query:
            import inspect
            # The number of arguments of the base Query constructor tells
            # the supported Django versions apart.
            base_init_args = inspect.getargspec(BaseQuery.__init__)[0]
            if len(base_init_args) == 3:
                # django 1.2
                query_args = (model, WhereNode)
            else:
                # django 1.1
                query_args = (model, None, WhereNode)
            query = Query(*query_args)
        super(QuerySet, self).__init__(model=model, query=query)
206