67d193607b7bc77c09ac74c4a802388a185df6f1
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 BollorĂ© telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
33 class Constraint(BaseConstraint):
34     """
35     An object that can be passed to WhereNode.add() and knows how to
36     pre-process itself prior to including in the WhereNode.
37     """
38     def process(self, lookup_type, value):
39         """
40         Returns a tuple of data suitable for inclusion in a WhereNode
41         instance.
42         """
43         # Because of circular imports, we need to import this here.
44         from django.db.models.base import ObjectDoesNotExist
45
46         try:
47             if self.field:
48                 params = self.field.get_db_prep_lookup(lookup_type, value)
49                 db_type = self.field.db_type()
50             else:
51                 params = CharField().get_db_prep_lookup(lookup_type, value)
52                 db_type = None
53         except ObjectDoesNotExist:
54             raise EmptyShortCircuit
55
56         return (self.alias, self.col, db_type), params
57
58 class WhereNode(BaseWhereNode):
59     def add(self, data, connector):
60         if not isinstance(data, (list, tuple)):
61             super(WhereNode, self).add(data, connector)
62             return
63
64         # we replace the native Constraint by our own
65         obj, lookup_type, value = data
66         if hasattr(obj, "process"):
67             obj = Constraint(obj.alias, obj.col, obj.field)
68         super(WhereNode, self).add((obj, lookup_type, value), connector)
69
70     def as_sql(self):
71         bits = []
72         for item in self.children:
73             if isinstance(item, WhereNode):
74                 bits.append(item.as_sql())
75                 continue
76             constraint, x, y, values = item
77             if hasattr(constraint, "col"):
78                 # django 1.2
79                 clause = "(%s=%s)" % (constraint.col, values)
80             else:
81                 # django 1.1
82                 (table, column, type) = constraint
83                 equal_bits = [ "(%s=%s)" % (column, value) for value in values ]
84                 if len(equal_bits) == 1:
85                     clause = equal_bits[0]
86                 else:
87                     clause = '(|%s)' % ''.join(equal_bits)
88             if self.negated:
89                 bits.append('(!%s)' % clause)
90             else:
91                 bits.append(clause)
92         if len(bits) == 1:
93             return bits[0]
94         elif self.connector == AND:
95             return '(&%s)' % ''.join(bits)
96         elif self.connector == OR:
97             return '(|%s)' % ''.join(bits)
98         else:
99             raise Exception("Unhandled WHERE connector: %s" % self.connector)
100
101 class Query(BaseQuery):
102     def results_iter(self):
103         # FIXME: use all object classes
104         filterstr = '(objectClass=%s)' % self.model.object_classes[0]
105         filterstr += self.where.as_sql()
106         filterstr = '(&%s)' % filterstr
107         attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
108
109         try:
110             vals = ldapdb.connection.search_s(
111                 self.model.base_dn,
112                 ldap.SCOPE_SUBTREE,
113                 filterstr=filterstr,
114                 attrlist=attrlist,
115             )
116         except:
117             raise self.model.DoesNotExist
118
119         # perform sorting
120         if self.extra_order_by:
121             ordering = self.extra_order_by
122         elif not self.default_ordering:
123             ordering = self.order_by
124         else:
125             ordering = self.order_by or self.model._meta.ordering
126         def getkey(x):
127             keys = []
128             for k in ordering:
129                 attr = self.model._meta.get_field(k).db_column
130                 keys.append(x[1].get(attr, '').lower())
131             return keys
132         vals = sorted(vals, key=lambda x: getkey(x))
133
134         # process results
135         for dn, attrs in vals:
136             row = []
137             for field in iter(self.model._meta.fields):
138                 if field.attname == 'dn':
139                     row.append(dn)
140                 else:
141                     row.append(attrs.get(field.db_column, None))
142             yield row
143
144 class QuerySet(BaseQuerySet):
145     def __init__(self, model=None, query=None):
146         if not query:
147             query = Query(model, None, WhereNode)
148         super(QuerySet, self).__init__(model, query)
149