allow dynamically adding object classes to types
[matthijs/upstream/django-ldapdb.git] / ldapdb / backends / ldap / compiler.py
1 # -*- coding: utf-8 -*-
2
3 # django-ldapdb
4 # Copyright (c) 2009-2010, Bolloré telecom
5 # All rights reserved.
6
7 # See AUTHORS file for a full list of contributors.
8
9 # Redistribution and use in source and binary forms, with or without modification,
10 # are permitted provided that the following conditions are met:
11
12 #     1. Redistributions of source code must retain the above copyright notice, 
13 #        this list of conditions and the following disclaimer.
14 #     
15 #     2. Redistributions in binary form must reproduce the above copyright 
16 #        notice, this list of conditions and the following disclaimer in the
17 #        documentation and/or other materials provided with the distribution.
18
19 #     3. Neither the name of Bolloré telecom nor the names of its contributors
20 #        may be used to endorse or promote products derived from this software
21 #        without specific prior written permission.
22
23 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #
34
35 import ldap
36
37 from django.db.models.sql import aggregates, compiler
38 from django.db.models.sql.where import AND, OR
39
40 def get_lookup_operator(lookup_type):
41     if lookup_type == 'gte':
42         return '>='
43     elif lookup_type == 'lte':
44         return '<='
45     else:
46         return '='
47
48 def query_as_ldap(query):
49     # TODO: Filtering on objectClass temporarily disabled, since this
50     # breaks Model.save() after an objectclass was added (it queries the
51     # database for the old values to see what changed, but filtering on
52     # the new objectclasses does not return the object).
53     #filterstr = ''.join(['(objectClass=%s)' % cls for cls in query.model.object_classes])
54     filterstr = ''
55     sql, params = where_as_ldap(query.where)
56     filterstr += sql
57     return '(&%s)' % filterstr
58
59 def where_as_ldap(self):
60     bits = []
61     for item in self.children:
62         if hasattr(item, 'as_sql'):
63             sql, params = where_as_ldap(item)
64             bits.append(sql)
65             continue
66
67         constraint, lookup_type, y, values = item
68         comp = get_lookup_operator(lookup_type)
69         if lookup_type == 'in':
70             equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
71             clause = '(|%s)' % ''.join(equal_bits)
72         else:
73             clause = "(%s%s%s)" % (constraint.col, comp, values)
74
75         bits.append(clause)
76
77     if not len(bits):
78         return '', []
79
80     if len(bits) == 1:
81         sql_string = bits[0]
82     elif self.connector == AND:
83         sql_string = '(&%s)' % ''.join(bits)
84     elif self.connector == OR:
85         sql_string = '(|%s)' % ''.join(bits)
86     else:
87         raise Exception("Unhandled WHERE connector: %s" % self.connector)
88
89     if self.negated:
90         sql_string = ('(!%s)' % sql_string)
91
92     return sql_string, []
93
94 class SQLCompiler(object):
95     def __init__(self, query, connection, using):
96         self.query = query
97         self.connection = connection
98         self.using = using
99
100     def execute_sql(self, result_type=compiler.MULTI):
101         if result_type !=compiler.SINGLE:
102             raise Exception("LDAP does not support MULTI queries")
103
104         for key, aggregate in self.query.aggregate_select.items():
105             if not isinstance(aggregate, aggregates.Count):
106                 raise Exception("Unsupported aggregate %s" % aggregate)
107
108         try:
109             vals = self.connection.search_s(
110                 self.query.model.base_dn,
111                 self.query.model.search_scope,
112                 filterstr=query_as_ldap(self.query),
113                 attrlist=['dn'],
114             )
115         except ldap.NO_SUCH_OBJECT:
116             vals = []
117
118         if not vals:
119             return None
120
121         output = []
122         for alias, col in self.query.extra_select.iteritems():
123             output.append(col[0])
124         for key, aggregate in self.query.aggregate_select.items():
125             if isinstance(aggregate, aggregates.Count):
126                 output.append(len(vals))
127             else:
128                 output.append(None)
129         return output
130
131     def results_iter(self):
132         if self.query.select_fields:
133             fields = self.query.select_fields
134         else:
135             fields = self.query.model._meta.fields
136
137         attrlist = [ x.db_column for x in fields if x.db_column ]
138
139         try:
140             vals = self.connection.search_s(
141                 self.query.model.base_dn,
142                 self.query.model.search_scope,
143                 filterstr=query_as_ldap(self.query),
144                 attrlist=attrlist,
145             )
146         except ldap.NO_SUCH_OBJECT:
147             return
148
149         # perform sorting
150         if self.query.extra_order_by:
151             ordering = self.query.extra_order_by
152         elif not self.query.default_ordering:
153             ordering = self.query.order_by
154         else:
155             ordering = self.query.order_by or self.query.model._meta.ordering
156         def cmpvals(x, y):
157             for fieldname in ordering:
158                 if fieldname.startswith('-'):
159                     fieldname = fieldname[1:]
160                     negate = True
161                 else:
162                     negate = False
163                 field = self.query.model._meta.get_field(fieldname)
164                 attr_x = field.from_ldap(x[1].get(field.db_column, []), connection=self.connection)
165                 attr_y = field.from_ldap(y[1].get(field.db_column, []), connection=self.connection)
166                 # perform case insensitive comparison
167                 if hasattr(attr_x, 'lower'):
168                     attr_x = attr_x.lower()
169                 if hasattr(attr_y, 'lower'):
170                     attr_y = attr_y.lower()
171                 val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
172                 if val:
173                     return val
174             return 0
175         vals = sorted(vals, cmp=cmpvals)
176
177         # process results
178         pos = 0
179         for dn, attrs in vals:
180             # FIXME : This is not optimal, we retrieve more results than we need
181             # but there is probably no other options as we can't perform ordering
182             # server side.
183             if (self.query.low_mark and pos < self.query.low_mark) or \
184                (self.query.high_mark is not None and pos >= self.query.high_mark):
185                 pos += 1
186                 continue
187             row = []
188             for field in iter(fields):
189                 if field.attname == 'dn':
190                     row.append(dn)
191                 elif hasattr(field, 'from_ldap'):
192                     row.append(field.from_ldap(attrs.get(field.db_column, []), connection=self.connection))
193                 else:
194                     row.append(None)
195             yield row
196             pos += 1
197
198 class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
199     pass
200
201 class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
202     def execute_sql(self, result_type=compiler.MULTI):
203         try:
204             vals = self.connection.search_s(
205                 self.query.model.base_dn,
206                 self.query.model.search_scope,
207                 filterstr=query_as_ldap(self.query),
208                 attrlist=['dn'],
209             )
210         except ldap.NO_SUCH_OBJECT:
211             return
212
213         # FIXME : there is probably a more efficient way to do this 
214         for dn, attrs in vals:
215             self.connection.delete_s(dn)
216
217 class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
218     pass
219
220 class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
221     pass
222
223 class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
224     pass
225