improve LDAP filter escaping
authorjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Wed, 19 May 2010 08:31:51 +0000 (08:31 +0000)
committerjlaine <jlaine@e071eeec-0327-468d-9b6a-08194a12b294>
Wed, 19 May 2010 08:31:51 +0000 (08:31 +0000)
git-svn-id: https://svn.bolloretelecom.eu/opensource/django-ldapdb/trunk@849 e071eeec-0327-468d-9b6a-08194a12b294

ldapdb/models/fields.py
ldapdb/models/query.py
ldapdb/tests.py

index fe91601f6053d8c9c74d27f3e330f880201d242d..78ec1c7569f3421d2b856e2d78856f70cbed70da 100644 (file)
@@ -25,19 +25,6 @@ class CharField(fields.CharField):
         kwargs['max_length'] = 200
         super(CharField, self).__init__(*args, **kwargs)
 
-    def get_db_prep_value(self, value):
-        """Returns field's value prepared for interacting with the database
-        backend.
-
-        Used by the default implementations of ``get_db_prep_save``and
-        `get_db_prep_lookup```
-        """
-        return value.replace('\\', '\\5c') \
-                    .replace('*', '\\2a') \
-                    .replace('(', '\\28') \
-                    .replace(')', '\\29') \
-                    .replace('\0', '\\00')
-
 class ImageField(fields.Field):
     pass
 
index f430286b65be5d14531e46a86b540eaa3a851ab6..877ea9e40ee74c0b323361e39ca2ea893ba79f4f 100644 (file)
@@ -21,7 +21,6 @@
 from copy import deepcopy
 import ldap
 
-from django.db.models.fields import Field
 from django.db.models.query import QuerySet as BaseQuerySet
 from django.db.models.query_utils import Q
 from django.db.models.sql import Query as BaseQuery
@@ -29,6 +28,14 @@ from django.db.models.sql.where import WhereNode as BaseWhereNode, Constraint as
 
 import ldapdb
 
+def escape_ldap_filter(value):
+    value = str(value)
+    return value.replace('\\', '\\5c') \
+                .replace('*', '\\2a') \
+                .replace('(', '\\28') \
+                .replace(')', '\\29') \
+                .replace('\0', '\\00')
+
 class Constraint(BaseConstraint):
     """
     An object that can be passed to WhereNode.add() and knows how to
@@ -43,13 +50,13 @@ class Constraint(BaseConstraint):
         from django.db.models.base import ObjectDoesNotExist
 
         if lookup_type == 'endswith':
-            params = ["*%s" % value]
+            params = ["*%s" % escape_ldap_filter(value)]
         elif lookup_type == 'startswith':
-            params = ["%s*" % value]
+            params = ["%s*" % escape_ldap_filter(value)]
         elif lookup_type == 'exact':
-            params = [value]
+            params = [escape_ldap_filter(value)]
         elif lookup_type == 'in':
-            params = [v for v in value]
+            params = [escape_ldap_filter(v) for v in value]
         else:
             raise TypeError("Field has invalid lookup: %s" % lookup_type)
 
index e3c787b008d3db8a7100a73f88943670cbd11cbf..be247fdbb05a06a830c4c7d3c2b3ea210b039d7d 100644 (file)
@@ -22,13 +22,16 @@ from django.test import TestCase
 from django.db.models.sql.where import Constraint, AND, OR
 
 from ldapdb.models.fields import CharField
-from ldapdb.models.query import WhereNode
-
-class FieldTestCase(TestCase):
-    def test_db_prep(self):
-        field = CharField()
+from ldapdb.models.query import WhereNode, escape_ldap_filter
 
 class WhereTestCase(TestCase):
+    def test_escape(self):
+        self.assertEquals(escape_ldap_filter('foo*bar'), 'foo\\2abar')
+        self.assertEquals(escape_ldap_filter('foo(bar'), 'foo\\28bar')
+        self.assertEquals(escape_ldap_filter('foo)bar'), 'foo\\29bar')
+        self.assertEquals(escape_ldap_filter('foo\\bar'), 'foo\\5cbar')
+        self.assertEquals(escape_ldap_filter('foo\\bar*wiz'), 'foo\\5cbar\\2awiz')
+
     def test_single(self):
         where = WhereNode()
         where.add((Constraint("cn", "cn", None), 'exact', "test"), AND)
@@ -46,6 +49,11 @@ class WhereTestCase(TestCase):
         where.add((Constraint("cn", "cn", None), 'in', ["foo", "bar"]), AND)
         self.assertEquals(where.as_sql(), "(|(cn=foo)(cn=bar))")
 
+    def test_escaped(self):
+        where = WhereNode()
+        where.add((Constraint("cn", "cn", None), 'exact', "(test)"), AND)
+        self.assertEquals(where.as_sql(), "(cn=\\28test\\29)")
+
     def test_and(self):
         where = WhereNode()
         where.add((Constraint("cn", "cn", None), 'exact', "foo"), AND)