add support for "greater or equal" and "less or equal" lookups on IntegerFields
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 Bolloré telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
def get_lookup_operator(lookup_type):
    """
    Map a Django lookup type to the comparison operator used in an LDAP
    filter item.

    'gte' becomes '>=' and 'lte' becomes '<='; every other lookup type
    (e.g. 'exact', 'in') falls back to the equality operator '='.
    """
    operators = {
        'gte': '>=',
        'lte': '<=',
    }
    return operators.get(lookup_type, '=')
40
class Constraint(BaseConstraint):
    """
    An object that can be passed to WhereNode.add() and knows how to
    pre-process itself prior to including in the WhereNode.
    """
    def process(self, lookup_type, value):
        """
        Returns a tuple of data suitable for inclusion in a WhereNode
        instance.

        The result is ``((alias, col, db_type), params)``, where ``params``
        is the value prepared by the field's get_db_prep_lookup(). When no
        field is attached, ldapdb's CharField prepares the value and
        ``db_type`` is None.

        Raises EmptyShortCircuit when preparing the value raises
        ObjectDoesNotExist.
        """
        # Because of circular imports, we need to import this here.
        from django.db.models.base import ObjectDoesNotExist
        # EmptyShortCircuit is not imported at module level; without this
        # import the raise below would fail with a NameError instead.
        from django.db.models.sql.where import EmptyShortCircuit

        try:
            if self.field:
                params = self.field.get_db_prep_lookup(lookup_type, value)
                db_type = self.field.db_type()
            else:
                params = CharField().get_db_prep_lookup(lookup_type, value)
                db_type = None
        except ObjectDoesNotExist:
            raise EmptyShortCircuit

        return (self.alias, self.col, db_type), params
65
class WhereNode(BaseWhereNode):
    """
    WHERE tree that renders itself as an LDAP search filter string
    (for example "(&(uid=foo)(uidNumber>=1000))") instead of SQL.
    """
    def add(self, data, connector):
        """
        Add a child to this node.

        For a ``(constraint, lookup_type, value)`` triple whose constraint
        has a ``process`` method, the constraint is replaced by this
        module's Constraint before delegating to the base add(). Anything
        that is not a list/tuple (e.g. a nested node) is passed through.
        """
        if not isinstance(data, (list, tuple)):
            super(WhereNode, self).add(data, connector)
            return

        # we replace the native Constraint by our own
        obj, lookup_type, value = data
        if hasattr(obj, "process"):
            obj = Constraint(obj.alias, obj.col, obj.field)
        super(WhereNode, self).add((obj, lookup_type, value), connector)

    def as_sql(self):
        """
        Return this node rendered as an LDAP filter string.

        Child filters are joined with '&' (AND) or '|' (OR) according to
        ``self.connector``. If the node is negated, the *combined* filter is
        wrapped in '(!...)', i.e. NOT (a AND b) rather than
        (NOT a) AND (NOT b); nested child nodes are therefore negated too.
        A node with no renderable children returns '' so callers can
        concatenate it without producing an empty "(&)" filter.

        Raises Exception for a connector other than AND or OR.
        """
        bits = []
        for item in self.children:
            if isinstance(item, WhereNode):
                clause = item.as_sql()
            else:
                clause = self._leaf_as_sql(item)
            # Empty subtrees contribute nothing to the filter.
            if clause:
                bits.append(clause)

        if not bits:
            return ''
        if len(bits) == 1:
            sql_string = bits[0]
        elif self.connector == AND:
            sql_string = '(&%s)' % ''.join(bits)
        elif self.connector == OR:
            sql_string = '(|%s)' % ''.join(bits)
        else:
            raise Exception("Unhandled WHERE connector: %s" % self.connector)

        if self.negated:
            sql_string = '(!%s)' % sql_string
        return sql_string

    @staticmethod
    def _leaf_as_sql(item):
        """Render one (constraint, lookup_type, annotation, values) child."""
        constraint, lookup_type, _annotation, values = item
        comp = get_lookup_operator(lookup_type)
        if hasattr(constraint, "col"):
            # django 1.2: constraint is a Constraint object and values is the
            # lookup value itself; only 'in' lookups carry several values.
            column = constraint.col
            if lookup_type != 'in':
                values = [values]
        else:
            # django 1.1: constraint is the (alias, column, db_type) tuple
            # built by Constraint.process(), and values is a list.
            (_alias, column, _db_type) = constraint
        # NOTE(review): values are interpolated as-is; filter escaping is
        # presumably done by the field's get_db_prep_lookup() -- confirm,
        # especially for the django 1.2 path.
        equal_bits = ["(%s%s%s)" % (column, comp, value) for value in values]
        if len(equal_bits) == 1:
            return equal_bits[0]
        return '(|%s)' % ''.join(equal_bits)
110
class Query(BaseQuery):
    """
    Query whose results come from an LDAP subtree search instead of SQL.
    """
    def results_iter(self):
        """
        Run the LDAP search for this query and yield one row per entry.

        The search runs under ``self.model.base_dn`` with the filter
        ``(&(objectClass=<first object class>)<where filter>)``. It requests
        the ``db_column`` of every local field that has one.

        Entries are sorted in Python on the lower-cased attribute values of
        the ordering fields. The ordering is extra_order_by if set;
        otherwise order_by, falling back to the model's Meta.ordering when
        default ordering is enabled. A leading '-' on a field name sorts
        that field in descending order.

        Each yielded row is a list with one item per field, in
        ``_meta.fields`` order. The item is the entry DN for the field whose
        attname is 'dn'. Otherwise it is the search result's value for the
        field's db_column, or None when that attribute is absent.

        Raises self.model.DoesNotExist if the LDAP search fails with an
        LDAP error.
        """
        # FIXME: use all object classes
        filterstr = '(&(objectClass=%s)%s)' % (
            self.model.object_classes[0], self.where.as_sql())
        attrlist = [x.db_column for x in self.model._meta.local_fields if x.db_column]

        try:
            vals = ldapdb.connection.search_s(
                self.model.base_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr,
                attrlist=attrlist,
            )
        except ldap.LDAPError:
            # Only LDAP errors are reported as DoesNotExist; a bare except
            # would also hide programming errors and KeyboardInterrupt.
            raise self.model.DoesNotExist

        # perform sorting
        if self.extra_order_by:
            ordering = self.extra_order_by
        elif not self.default_ordering:
            ordering = self.order_by
        else:
            ordering = self.order_by or self.model._meta.ordering
        # Python's sort is stable. Sorting by the least significant field
        # first and the most significant last therefore orders by all
        # fields at once, and lets each field pick its own direction.
        for field_name in reversed(list(ordering)):
            descending = field_name.startswith('-')
            if descending:
                field_name = field_name[1:]
            attr = self.model._meta.get_field(field_name).db_column
            vals = sorted(vals,
                          key=lambda x, attr=attr: x[1].get(attr, '').lower(),
                          reverse=descending)

        # process results
        fields = self.model._meta.fields
        for dn, attrs in vals:
            yield [dn if field.attname == 'dn' else attrs.get(field.db_column, None)
                   for field in fields]
153
class QuerySet(BaseQuerySet):
    """
    QuerySet for LDAP-backed models.

    When no (truthy) query is supplied, a new ldapdb Query is built whose
    WHERE tree uses this module's LDAP-aware WhereNode.
    """
    def __init__(self, model=None, query=None):
        query = query or Query(model, None, WhereNode)
        super(QuerySet, self).__init__(model, query)
159