+ pos += 1
+
+
+class WhereNode(BaseWhereNode):
+ def add(self, data, connector):
+ if not isinstance(data, (list, tuple)):
+ super(WhereNode, self).add(data, connector)
+ return
+
+ # we replace the native Constraint by our own
+ obj, lookup_type, value = data
+ if hasattr(obj, "process"):
+ obj = Constraint(obj.alias, obj.col, obj.field)
+ super(WhereNode, self).add((obj, lookup_type, value), connector)
+
+ def as_sql(self, qn=None, connection=None):
+ bits = []
+ for item in self.children:
+ if hasattr(item, 'as_sql'):
+ sql, params = item.as_sql(qn=qn, connection=connection)
+ bits.append(sql)
+ continue
+
+ constraint, lookup_type, y, values = item
+ comp = get_lookup_operator(lookup_type)
+ if lookup_type == 'in':
+ equal_bits = [ "(%s%s%s)" % (constraint.col, comp, value) for value in values ]
+ clause = '(|%s)' % ''.join(equal_bits)
+ else:
+ clause = "(%s%s%s)" % (constraint.col, comp, values)
+
+ bits.append(clause)
+
+ if not len(bits):
+ return '', []
+
+ if len(bits) == 1:
+ sql_string = bits[0]
+ elif self.connector == AND:
+ sql_string = '(&%s)' % ''.join(bits)
+ elif self.connector == OR:
+ sql_string = '(|%s)' % ''.join(bits)
+ else:
+ raise Exception("Unhandled WHERE connector: %s" % self.connector)
+
+ if self.negated:
+ sql_string = ('(!%s)' % sql_string)
+
+ return sql_string, []
+
+class Query(BaseQuery):
+ def __init__(self, *args, **kwargs):
+ super(Query, self).__init__(*args, **kwargs)
+ self.connection = ldapdb.connection
+
+ def _ldap_filter(self):
+ filterstr = ''.join(['(objectClass=%s)' % cls for cls in self.model.object_classes])
+ sql, params = self.where.as_sql()
+ filterstr += sql
+ return '(&%s)' % filterstr
+
+ def get_count(self, using):
+ try:
+ vals = ldapdb.connection.search_s(
+ self.model.base_dn,
+ self.model.search_scope,
+ filterstr=self._ldap_filter(),
+ attrlist=[],
+ )
+ except ldap.NO_SUCH_OBJECT:
+ return 0
+
+ number = len(vals)
+
+ # apply limit and offset
+ number = max(0, number - self.low_mark)
+ if self.high_mark is not None:
+ number = min(number, self.high_mark - self.low_mark)
+
+ return number
+
+ def get_compiler(self, using=None, connection=None):
+ return Compiler(self, ldapdb.connection, using)
+
+ def has_results(self, using):
+ return self.get_count(using) != 0