add some FIXMEs
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Translate a Django lookup type into an LDAP filter comparison operator.

    'gte' becomes '>=' and 'lte' becomes '<='; any other lookup type is
    rendered with the equality operator '='.

    NOTE(review): LDAP search filters (RFC 4515) have no strict '<' / '>'
    operators, so 'gt' / 'lt' currently fall back to '=' -- confirm this
    is intended.
    """
    for django_lookup, ldap_operator in (('gte', '>='), ('lte', '<=')):
        if lookup_type == django_lookup:
            return ldap_operator
    return '='
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.

    When the constraint has no model field, lookup values are prepared
    with ldapdb's CharField instead.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance: ``((alias, col, db_type), params)``.

        ``db_type`` is None when the constraint has no field.

        Raises EmptyShortCircuit when preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not imported at module level; without this
        # import the except clause below would raise NameError instead of
        # the short-circuit signal the caller expects.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class WhereNode(BaseWhereNode):
    """
    WHERE tree that renders itself as an LDAP search filter in prefix
    notation (``(&...)``, ``(|...)``, ``(!...)``) rather than as SQL.
    """
    def add(self, data, connector):
        """
        Add ``data`` to this node under ``connector``.

        Anything that is not a list/tuple is handed to the base class
        unchanged. For a ``(obj, lookup_type, value)`` triple, an ``obj``
        that has a ``process`` method is rebuilt as this module's Constraint
        (same alias, col and field) before being handed to the base class.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self, qn=None, connection=None):
        """
        Render this node as an LDAP filter string.

        Returns ``(filter, params)``; ``params`` is always an empty list
        because values are written directly into the filter.

        Children with an ``as_sql`` method (nested nodes) are rendered
        recursively. Every other child is a 4-tuple
        ``(constraint, lookup_type, <unused>, values)``. ``constraint`` is
        either an object with a ``col`` attribute (django 1.2) or an
        ``(alias, column, db_type)`` tuple (django 1.1).

        A negated node wraps its whole combined filter in ``(!...)``.
        Negating each child separately would turn NOT (a AND b) into
        (NOT a) AND (NOT b).

        NOTE(review): values are interpolated without escaping here;
        presumably the field's get_db_prep_lookup escapes LDAP filter
        special characters -- verify.

        Raises Exception for a connector other than AND / OR.
        """
        bits = []
        for item in self.children:
            if hasattr(item, 'as_sql'):
                # nested node: it applies its own negation
                sql, _ = item.as_sql(qn=qn, connection=connection)
                bits.append(sql)
                continue

            constraint, lookup_type, _, values = item
            comp = get_lookup_operator(lookup_type)
            if hasattr(constraint, "col"):
                # django 1.2
                column = constraint.col
                if lookup_type == 'in':
                    equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                    clause = '(|%s)' % ''.join(equal_bits)
                else:
                    clause = "(%s%s%s)" % (column, comp, values)
            else:
                # django 1.1
                _, column, _ = constraint
                equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
                if len(equal_bits) == 1:
                    clause = equal_bits[0]
                else:
                    clause = '(|%s)' % ''.join(equal_bits)
            bits.append(clause)

        # NOTE(review): with no children this yields '(&)' or '(|)', the
        # RFC 4526 absolute true/false filters; confirm the target servers
        # support them.
        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string, []
118
class Query(BaseQuery):
    """
    Query that runs as an LDAP subtree search below the model's ``base_dn``
    through ``ldapdb.connection``, instead of being compiled to SQL.
    """
    def __init__(self, *args, **kwargs):
        super(Query, self).__init__(*args, **kwargs)
        self.connection = ldapdb.connection

    def _get_filterstr(self):
        """
        Return the LDAP filter for this query.

        The filter holds one ``(objectClass=...)`` term per entry in the
        model's ``object_classes``, AND-ed with the rendered WHERE tree.
        """
        class_bits = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
        where_sql, _ = self.where.as_sql()
        return '(&%s%s)' % (class_bits, where_sql)

    def _search(self, filterstr, attrlist):
        """
        Run a subtree search below the model's ``base_dn`` and return the
        raw result list.

        Raises the model's DoesNotExist when python-ldap raises an
        LDAPError. Any other exception (e.g. a programming error or
        KeyboardInterrupt) propagates unchanged.
        """
        try:
            return ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=attrlist,
            )
        except ldap.LDAPError:
            raise self.model.DoesNotExist

    def get_count(self):
        """
        Return the number of LDAP entries matching this query.

        Raises the model's DoesNotExist if the LDAP search fails.
        """
        # Only the number of results is used.
        # NOTE(review): per RFC 4511 an empty attribute list asks for all
        # user attributes; ['1.1'] would request none -- consider.
        return len(self._search(self._get_filterstr(), []))

    def results_iter(self):
        """
        Yield one row per matching LDAP entry, sorted by the query ordering.

        Each row is a list with one value per field of
        ``self.model._meta.fields``, in order. The field whose attname is
        'dn' gets the entry's DN. Every other field gets the entry's value
        for the field's ``db_column``, or None when the entry lacks it.

        Raises the model's DoesNotExist if the LDAP search fails.
        """
        filterstr = self._get_filterstr()
        attrlist = [field.db_column for field in self.model._meta.local_fields if field.db_column]
        vals = self._search(filterstr, attrlist)

        # perform sorting
        if self.extra_order_by:
            ordering = self.extra_order_by
        elif not self.default_ordering:
            ordering = self.order_by
        else:
            ordering = self.order_by or self.model._meta.ordering

        if ordering and len(vals) > 1:
            # Resolve each ordering entry to (ldap attribute, descending)
            # once, rather than calling get_field() on every comparison.
            sort_keys = []
            for name in ordering:
                descending = name.startswith('-')
                if descending:
                    name = name[1:]
                sort_keys.append((self.model._meta.get_field(name).db_column, descending))

            def cmpvals(x, y):
                for attr, reverse in sort_keys:
                    # case-insensitive; a missing attribute compares as ''
                    # NOTE(review): raw python-ldap results hold attribute
                    # values as lists, which have no .lower(); confirm what
                    # ldapdb.connection.search_s returns here.
                    value_x = x[1].get(attr, '').lower()
                    value_y = y[1].get(attr, '').lower()
                    if reverse:
                        result = cmp(value_y, value_x)
                    else:
                        result = cmp(value_x, value_y)
                    if result:
                        return result
                return 0

            vals = sorted(vals, cmp=cmpvals)

        # process results
        fields = self.model._meta.fields
        for dn, attrs in vals:
            yield [dn if field.attname == 'dn' else attrs.get(field.db_column, None)
                   for field in fields]
191
class QuerySet(BaseQuerySet):
    """
    QuerySet for LDAP-backed models.

    When no query is supplied, an ldapdb Query using the ldapdb WhereNode
    is built. The ``using`` argument is accepted but not passed on to the
    base class.
    """
    def __init__(self, model=None, query=None, using=None):
        if not query:
            import inspect
            # The base Query constructor's argument names (self included)
            # tell the Django versions apart: three names means django 1.2,
            # which takes (model, where); otherwise django 1.1, where the
            # WhereNode class is passed third and None is passed second.
            base_arg_names = inspect.getargspec(BaseQuery.__init__)[0]
            if len(base_arg_names) == 3:
                query_args = (WhereNode,)
            else:
                query_args = (None, WhereNode)
            query = Query(model, *query_args)
        super(QuerySet, self).__init__(model=model, query=query)
204