X-Git-Url: https://git.stderr.nl/gitweb?p=matthijs%2Fupstream%2Fdjango-ldapdb.git;a=blobdiff_plain;f=examples%2Ftests.py;h=65f051b80b20cbc2d99e4c8fd50d31605212b470;hp=1b4a42ca4ec11f514a2328c6adaabfd7c78a2874;hb=649f74436527abbe004a04856a1be212b972057b;hpb=50d3fcb1ad4326a55bb156fd641ce40bf52a9a51 diff --git a/examples/tests.py b/examples/tests.py index 1b4a42c..65f051b 100644 --- a/examples/tests.py +++ b/examples/tests.py @@ -32,6 +32,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +from django.db import connections, router from django.db.models import Q from django.test import TestCase @@ -39,26 +40,30 @@ import ldap import ldapdb from ldapdb.backends.ldap.compiler import query_as_ldap from examples.models import LdapUser, LdapGroup - + class BaseTestCase(TestCase): def setUp(self): - cursor = ldapdb.connection._cursor() - for dn in [LdapGroup.base_dn, LdapUser.base_dn]: - rdn = dn.split(',')[0] + for model in [LdapGroup, LdapUser]: + using = router.db_for_write(model) + connection = connections[using] + + rdn = model.base_dn.split(',')[0] key, val = rdn.split('=') attrs = [('objectClass', ['top', 'organizationalUnit']), (key, [val])] try: - cursor.connection.add_s(dn, attrs) + connection.add_s(model.base_dn, attrs) except ldap.ALREADY_EXISTS: pass def tearDown(self): - cursor = ldapdb.connection._cursor() - for base in [LdapGroup.base_dn, LdapUser.base_dn]: + for model in [LdapGroup, LdapUser]: + using = router.db_for_write(model) + connection = connections[using] + try: - results = cursor.connection.search_s(base, ldap.SCOPE_SUBTREE) + results = connection.search_s(model.base_dn, ldap.SCOPE_SUBTREE) for dn, attrs in reversed(results): - cursor.connection.delete_s(dn) + connection.delete_s(dn) except ldap.NO_SUCH_OBJECT: pass @@ -311,10 +316,9 @@ class ScopedTestCase(BaseTestCase): def setUp(self): super(ScopedTestCase, self).setUp() - cursor = ldapdb.connection._cursor() self.scoped_dn = "ou=contacts,%s" % LdapGroup.base_dn attrs = [('objectClass', ['top', 'organizationalUnit']), ("ou", ["contacts"])] - cursor.connection.add_s(self.scoped_dn, attrs) + ldapdb.connection.add_s(self.scoped_dn, attrs) def test_scope(self): ScopedGroup = LdapGroup.scoped(self.scoped_dn)