fix lookups with django 1.2
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 BollorĂ© telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
33 def get_lookup_operator(lookup_type):
34     if lookup_type == 'gte':
35         return '>='
36     elif lookup_type == 'lte':
37         return '<='
38     else:
39         return '='
40
41 class Constraint(BaseConstraint):
42     """
43     An object that can be passed to WhereNode.add() and knows how to
44     pre-process itself prior to including in the WhereNode.
45     """
46     def process(self, lookup_type, value):
47         """
48         Returns a tuple of data suitable for inclusion in a WhereNode
49         instance.
50         """
51         # Because of circular imports, we need to import this here.
52         from django.db.models.base import ObjectDoesNotExist
53
54         try:
55             if self.field:
56                 params = self.field.get_db_prep_lookup(lookup_type, value)
57                 db_type = self.field.db_type()
58             else:
59                 params = CharField().get_db_prep_lookup(lookup_type, value)
60                 db_type = None
61         except ObjectDoesNotExist:
62             raise EmptyShortCircuit
63
64         return (self.alias, self.col, db_type), params
65
66 class WhereNode(BaseWhereNode):
67     def add(self, data, connector):
68         if not isinstance(data, (list, tuple)):
69             super(WhereNode, self).add(data, connector)
70             return
71
72         # we replace the native Constraint by our own
73         obj, lookup_type, value = data
74         if hasattr(obj, "process"):
75             obj = Constraint(obj.alias, obj.col, obj.field)
76         super(WhereNode, self).add((obj, lookup_type, value), connector)
77
78     def as_sql(self):
79         bits = []
80         for item in self.children:
81             if isinstance(item, WhereNode):
82                 bits.append(item.as_sql())
83                 continue
84
85             constraint, lookup_type, y, values = item
86             comp = get_lookup_operator(lookup_type)
87             if hasattr(constraint, "col"):
88                 # django 1.2
89                 column = constraint.col
90                 if lookup_type == 'in':
91                     equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
92                     clause = '(|%s)' % ''.join(equal_bits)
93                 else:
94                     clause = "(%s%s%s)" % (constraint.col, comp, values)
95             else:
96                 # django 1.1
97                 (table, column, db_type) = constraint
98                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
99                 if len(equal_bits) == 1:
100                     clause = equal_bits[0]
101                 else:
102                     clause = '(|%s)' % ''.join(equal_bits)
103
104             if self.negated:
105                 bits.append('(!%s)' % clause)
106             else:
107                 bits.append(clause)
108         if len(bits) == 1:
109             return bits[0]
110         elif self.connector == AND:
111             return '(&%s)' % ''.join(bits)
112         elif self.connector == OR:
113             return '(|%s)' % ''.join(bits)
114         else:
115             raise Exception("Unhandled WHERE connector: %s" % self.connector)
116
117 class Query(BaseQuery):
118     def results_iter(self):
119         # FIXME: use all object classes
120         filterstr = '(objectClass=%s)' % self.model.object_classes[0]
121         filterstr += self.where.as_sql()
122         filterstr = '(&%s)' % filterstr
123         attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
124
125         try:
126             vals = ldapdb.connection.search_s(
127                 self.model.base_dn,
128                 ldap.SCOPE_SUBTREE,
129                 filterstr=filterstr,
130                 attrlist=attrlist,
131             )
132         except:
133             raise self.model.DoesNotExist
134
135         # perform sorting
136         if self.extra_order_by:
137             ordering = self.extra_order_by
138         elif not self.default_ordering:
139             ordering = self.order_by
140         else:
141             ordering = self.order_by or self.model._meta.ordering
142         def getkey(x):
143             keys = []
144             for k in ordering:
145                 attr = self.model._meta.get_field(k).db_column
146                 keys.append(x[1].get(attr, '').lower())
147             return keys
148         vals = sorted(vals, key=lambda x: getkey(x))
149
150         # process results
151         for dn, attrs in vals:
152             row = []
153             for field in iter(self.model._meta.fields):
154                 if field.attname == 'dn':
155                     row.append(dn)
156                 else:
157                     row.append(attrs.get(field.db_column, None))
158             yield row
159
160 class QuerySet(BaseQuerySet):
161     def __init__(self, model=None, query=None):
162         if not query:
163             query = Query(model, None, WhereNode)
164         super(QuerySet, self).__init__(model, query)
165