b8e4897a645e1132424c1b735aff28928235408d
[matthijs/upstream/django-ldapdb.git] / ldapdb / models / query.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (C) 2009-2010 BollorĂ© telecom
5 # See AUTHORS file for a full list of contributors.
6
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 from copy import deepcopy
22 import ldap
23
24 from django.db.models.query import QuerySet as BaseQuerySet
25 from django.db.models.query_utils import Q
26 from django.db.models.sql import Query as BaseQuery
27 from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as BaseConstraint, AND, OR
28
29 import ldapdb
30
31 from ldapdb.models.fields import CharField
32
33 def get_lookup_operator(lookup_type):
34     if lookup_type == 'gte':
35         return '>='
36     elif lookup_type == 'lte':
37         return '<='
38     else:
39         return '='
40
41 class Constraint(BaseConstraint):
42     """
43     An object that can be passed to WhereNode.add() and knows how to
44     pre-process itself prior to including in the WhereNode.
45     """
46     def process(self, lookup_type, value):
47         """
48         Returns a tuple of data suitable for inclusion in a WhereNode
49         instance.
50         """
51         # Because of circular imports, we need to import this here.
52         from django.db.models.base import ObjectDoesNotExist
53
54         try:
55             if self.field:
56                 params = self.field.get_db_prep_lookup(lookup_type, value)
57                 db_type = self.field.db_type()
58             else:
59                 params = CharField().get_db_prep_lookup(lookup_type, value)
60                 db_type = None
61         except ObjectDoesNotExist:
62             raise EmptyShortCircuit
63
64         return (self.alias, self.col, db_type), params
65
66 class WhereNode(BaseWhereNode):
67     def add(self, data, connector):
68         if not isinstance(data, (list, tuple)):
69             super(WhereNode, self).add(data, connector)
70             return
71
72         # we replace the native Constraint by our own
73         obj, lookup_type, value = data
74         if hasattr(obj, "process"):
75             obj = Constraint(obj.alias, obj.col, obj.field)
76         super(WhereNode, self).add((obj, lookup_type, value), connector)
77
78     def as_sql(self):
79         bits = []
80         for item in self.children:
81             if isinstance(item, WhereNode):
82                 bits.append(item.as_sql())
83                 continue
84             constraint, lookup_type, y, values = item
85             comp = get_lookup_operator(lookup_type)
86             if hasattr(constraint, "col"):
87                 # django 1.2
88                 clause = "(%s%s%s)" % (constraint.col, comp, values)
89             else:
90                 # django 1.1
91                 (table, column, db_type) = constraint
92                 equal_bits = [ "(%s%s%s)" % (column, comp, value) for value in values ]
93                 if len(equal_bits) == 1:
94                     clause = equal_bits[0]
95                 else:
96                     clause = '(|%s)' % ''.join(equal_bits)
97             if self.negated:
98                 bits.append('(!%s)' % clause)
99             else:
100                 bits.append(clause)
101         if len(bits) == 1:
102             return bits[0]
103         elif self.connector == AND:
104             return '(&%s)' % ''.join(bits)
105         elif self.connector == OR:
106             return '(|%s)' % ''.join(bits)
107         else:
108             raise Exception("Unhandled WHERE connector: %s" % self.connector)
109
110 class Query(BaseQuery):
111     def results_iter(self):
112         # FIXME: use all object classes
113         filterstr = '(objectClass=%s)' % self.model.object_classes[0]
114         filterstr += self.where.as_sql()
115         filterstr = '(&%s)' % filterstr
116         attrlist = [ x.db_column for x in self.model._meta.local_fields if x.db_column ]
117
118         try:
119             vals = ldapdb.connection.search_s(
120                 self.model.base_dn,
121                 ldap.SCOPE_SUBTREE,
122                 filterstr=filterstr,
123                 attrlist=attrlist,
124             )
125         except:
126             raise self.model.DoesNotExist
127
128         # perform sorting
129         if self.extra_order_by:
130             ordering = self.extra_order_by
131         elif not self.default_ordering:
132             ordering = self.order_by
133         else:
134             ordering = self.order_by or self.model._meta.ordering
135         def getkey(x):
136             keys = []
137             for k in ordering:
138                 attr = self.model._meta.get_field(k).db_column
139                 keys.append(x[1].get(attr, '').lower())
140             return keys
141         vals = sorted(vals, key=lambda x: getkey(x))
142
143         # process results
144         for dn, attrs in vals:
145             row = []
146             for field in iter(self.model._meta.fields):
147                 if field.attname == 'dn':
148                     row.append(dn)
149                 else:
150                     row.append(attrs.get(field.db_column, None))
151             yield row
152
153 class QuerySet(BaseQuerySet):
154     def __init__(self, model=None, query=None):
155         if not query:
156             query = Query(model, None, WhereNode)
157         super(QuerySet, self).__init__(model, query)
158