"""Functional tests for domain management.

The scenarios create, update and delete domains through either the admin
site (driven with Selenium) or the REST API, and then check the result by
querying the configured master and slave servers with ``dig``.
"""
import functools
import os
import time
import socket

from django.conf import settings as djsettings
from django.core.urlresolvers import reverse
from selenium.webdriver.support.select import Select

from orchestra.apps.orchestration.models import Server, Route
from orchestra.utils.tests import BaseLiveServerTestCase, random_ascii, snapshot_on_error
from orchestra.utils.system import run

from ... import settings, utils, backends
from ...models import Domain, Record


# Every run() call in this module is made with display=False.
run = functools.partial(run, display=False)


class DomainTestMixin(object):
    """Interface-agnostic domain test scenarios.

    Concrete test cases combine this mixin with classes that implement
    ``add_route()``, ``add()``, ``delete()`` and ``update()``.
    """
    # Server host names come from the environment, defaulting to localhost.
    MASTER_SERVER = os.environ.get('ORCHESTRA_MASTER_SERVER', 'localhost')
    SLAVE_SERVER = os.environ.get('ORCHESTRA_SLAVE_SERVER', 'localhost')
    # Addresses are resolved once, when the class body is executed.
    MASTER_SERVER_ADDR = socket.gethostbyname(MASTER_SERVER)
    SLAVE_SERVER_ADDR = socket.gethostbyname(SLAVE_SERVER)

    def setUp(self):
        """Enable DEBUG and build randomized domain names and record sets."""
        # Remember the current value so tearDown() can put it back instead of
        # leaking DEBUG = True into whatever runs after this test.
        self._previous_debug = djsettings.DEBUG
        djsettings.DEBUG = True
        super(DomainTestMixin, self).setUp()
        self.domain_name = 'orchestra%s.lan' % random_ascii(10)
        self.domain_records = (
            (Record.MX, '10 mail.orchestra.lan.'),
            (Record.MX, '20 mail2.orchestra.lan.'),
            (Record.NS, 'ns1.%s.' % self.domain_name),
            (Record.NS, 'ns2.%s.' % self.domain_name),
        )
        self.domain_update_records = (
            (Record.MX, '30 mail3.orchestra.lan.'),
            (Record.MX, '40 mail4.orchestra.lan.'),
            (Record.NS, 'ns1.%s.' % self.domain_name),
            (Record.NS, 'ns2.%s.' % self.domain_name),
        )
        self.ns1_name = 'ns1.%s' % self.domain_name
        self.ns1_records = (
            (Record.A, '%s' % self.SLAVE_SERVER_ADDR),
        )
        self.ns2_name = 'ns2.%s' % self.domain_name
        self.ns2_records = (
            (Record.A, '%s' % self.MASTER_SERVER_ADDR),
        )
        self.www_name = 'www.%s' % self.domain_name
        self.www_records = (
            (Record.CNAME, 'external.server.org.'),
        )
        self.django_domain_name = 'django%s.lan' % random_ascii(10)

    def tearDown(self):
        """Delete the main test domain if it still exists and restore DEBUG."""
        try:
            try:
                self.delete(self.domain_name)
            except Domain.DoesNotExist:
                # The test already removed it (or never created it).
                pass
            super(DomainTestMixin, self).tearDown()
        finally:
            djsettings.DEBUG = self._previous_debug

    def add_route(self):
        """Register the servers and routes to test against (abstract)."""
        raise NotImplementedError

    def add(self, domain_name, records):
        """Create ``domain_name`` with ``records`` as (type, value) pairs (abstract)."""
        raise NotImplementedError

    def delete(self, domain_name, records=None):
        """Delete ``domain_name`` (abstract).

        ``records`` is accepted only for backward compatibility: tearDown(),
        the test scenarios and the implementations in this module all call
        delete() with the domain name alone.
        """
        raise NotImplementedError

    def update(self, domain_name, records):
        """Set the records of ``domain_name`` to ``records`` (abstract)."""
        raise NotImplementedError

    def _validate_soa(self, context):
        """Assert that the SOA record served for the domain has the default values.

        ``context`` is a dict with the ``server_addr`` to query and the
        ``domain_name`` to ask for.
        """
        dig_soa = r'dig @%(server_addr)s %(domain_name)s SOA|grep "\sSOA\s"'
        soa = run(dig_soa % context).stdout.split()
        # testdomain.org. 3600 IN SOA ns.example.com. hostmaster.example.com. 2014021100 86400 7200 2419200 3600
        self.assertEqual('%(domain_name)s.' % context, soa[0])
        self.assertEqual('3600', soa[1])
        self.assertEqual('IN', soa[2])
        self.assertEqual('SOA', soa[3])
        self.assertEqual('%s.' % settings.DOMAINS_DEFAULT_NAME_SERVER, soa[4])
        hostmaster = utils.format_hostmaster(settings.DOMAINS_DEFAULT_HOSTMASTER)
        self.assertEqual(hostmaster, soa[5])

    def _validate_ns(self, context):
        """Assert that exactly the ns1/ns2 NS records are served for the domain."""
        dig_ns = r'dig @%(server_addr)s %(domain_name)s NS|grep "\sNS\s"'
        name_servers = run(dig_ns % context).stdout.splitlines()
        # The expected targets are always built from self.domain_name, even
        # when a different domain is being validated.
        ns_records = ['ns1.%s.' % self.domain_name, 'ns2.%s.' % self.domain_name]
        self.assertEqual(2, len(name_servers))
        for ns in name_servers:
            ns = ns.split()
            # testdomain.org. 3600 IN NS ns1.orchestra.lan.
            self.assertEqual('%(domain_name)s.' % context, ns[0])
            self.assertEqual('3600', ns[1])
            self.assertEqual('IN', ns[2])
            self.assertEqual('NS', ns[3])
            self.assertIn(ns[4], ns_records)

    def _validate_mx(self, context, priorities, exchanges):
        """Assert that every MX record served uses one of the expected values.

        ``priorities`` and ``exchanges`` list the acceptable preference and
        mail server fields, both as strings.
        """
        dig_mx = r'dig @%(server_addr)s %(domain_name)s MX|grep "\sMX\s"'
        mail_servers = run(dig_mx % context).stdout
        # NOTE(review): an empty answer is presumably rejected by run() itself
        # (grep exits with 1 and no error_codes are whitelisted here) - confirm.
        for mx in mail_servers.splitlines():
            mx = mx.split()
            # testdomain.org. 3600 IN MX 10 mail.orchestra.lan.
            self.assertEqual('%(domain_name)s.' % context, mx[0])
            self.assertEqual('3600', mx[1])
            self.assertEqual('IN', mx[2])
            self.assertEqual('MX', mx[3])
            self.assertIn(mx[4], priorities)
            self.assertIn(mx[5], exchanges)

    def validate_add(self, server_addr, domain_name):
        """Check that ``server_addr`` serves the initial records of ``domain_name``."""
        context = {
            'domain_name': domain_name,
            'server_addr': server_addr
        }
        self._validate_soa(context)
        self._validate_ns(context)
        self._validate_mx(
            context, ['10', '20'], ['mail2.orchestra.lan.', 'mail.orchestra.lan.'])

    def validate_delete(self, server_addr, domain_name):
        """Check that ``server_addr`` no longer serves our SOA for ``domain_name``."""
        context = {
            'domain_name': domain_name,
            'server_addr': server_addr
        }
        # No record type is given to dig here, unlike in the other checks;
        # exit code 1 (grep found no SOA line) is accepted as well.
        dig_soa = r'dig @%(server_addr)s %(domain_name)s|grep "\sSOA\s"'
        soa = run(dig_soa % context, error_codes=[0, 1]).stdout
        if soa:
            # Some SOA line was returned: it must not carry our defaults.
            soa = soa.split()
            self.assertEqual('IN', soa[2])
            self.assertEqual('SOA', soa[3])
            self.assertNotEqual('%s.' % settings.DOMAINS_DEFAULT_NAME_SERVER, soa[4])
            hostmaster = utils.format_hostmaster(settings.DOMAINS_DEFAULT_HOSTMASTER)
            self.assertNotEqual(hostmaster, soa[5])

    def validate_update(self, server_addr, domain_name):
        """Check that ``server_addr`` serves the updated records of ``domain_name``.

        Raises Domain.DoesNotExist if the domain is not in the database.
        """
        Domain.objects.get(name=domain_name)
        context = {
            'domain_name': domain_name,
            'server_addr': server_addr
        }
        self._validate_soa(context)
        self._validate_ns(context)
        # Every returned MX line is checked, as validate_add() does.
        self._validate_mx(
            context, ['30', '40'], ['mail3.orchestra.lan.', 'mail4.orchestra.lan.'])
        dig_cname = r'dig @%(server_addr)s www.%(domain_name)s CNAME|grep "\sCNAME\s"'
        cname = run(dig_cname % context).stdout.split()
        # www.testdomain.org. 3600 IN CNAME external.server.org.
        self.assertEqual('www.%(domain_name)s.' % context, cname[0])
        self.assertEqual('3600', cname[1])
        self.assertEqual('IN', cname[2])
        self.assertEqual('CNAME', cname[3])
        self.assertEqual('external.server.org.', cname[4])

    def test_add(self):
        self.add(self.ns1_name, self.ns1_records)
        self.add(self.ns2_name, self.ns2_records)
        self.add(self.domain_name, self.domain_records)
        self.validate_add(self.MASTER_SERVER_ADDR, self.domain_name)
        # NOTE(review): presumably gives the slave time to receive the zone.
        time.sleep(0.5)
        self.validate_add(self.SLAVE_SERVER_ADDR, self.domain_name)

    def test_delete(self):
        self.add(self.ns1_name, self.ns1_records)
        self.add(self.ns2_name, self.ns2_records)
        self.add(self.domain_name, self.domain_records)
        self.delete(self.domain_name)
        for name in [self.domain_name, self.ns1_name, self.ns2_name]:
            self.validate_delete(self.MASTER_SERVER_ADDR, name)
            self.validate_delete(self.SLAVE_SERVER_ADDR, name)

    def test_update(self):
        self.add(self.ns1_name, self.ns1_records)
        self.add(self.ns2_name, self.ns2_records)
        self.add(self.domain_name, self.domain_records)
        self.update(self.domain_name, self.domain_update_records)
        self.add(self.www_name, self.www_records)
        self.validate_update(self.MASTER_SERVER_ADDR, self.domain_name)
        # NOTE(review): presumably gives the slave time to receive the update.
        time.sleep(5)
        self.validate_update(self.SLAVE_SERVER_ADDR, self.domain_name)

    def test_add_add_delete_delete(self):
        self.add(self.ns1_name, self.ns1_records)
        self.add(self.ns2_name, self.ns2_records)
        self.add(self.domain_name, self.domain_records)
        self.add(self.django_domain_name, self.domain_records)
        self.delete(self.domain_name)
        # The second domain must be unaffected by deleting the first one.
        self.validate_add(self.MASTER_SERVER_ADDR, self.django_domain_name)
        self.validate_add(self.SLAVE_SERVER_ADDR, self.django_domain_name)
        self.delete(self.django_domain_name)
        self.validate_delete(self.MASTER_SERVER_ADDR, self.django_domain_name)
        self.validate_delete(self.SLAVE_SERVER_ADDR, self.django_domain_name)

    def test_bad_creation(self):
        # Unlike the other scenarios, the ns1/ns2 hosts are not created first;
        # the creation is expected to be rejected.
        self.assertRaises((self.rest.ResponseStatusError, AssertionError),
            self.add, self.domain_name, self.domain_records)


class AdminDomainMixin(DomainTestMixin):
    """Implements the domain operations by driving the admin site with Selenium."""

    def setUp(self):
        """Run the common set-up, then register the routes and log into the admin."""
        super(AdminDomainMixin, self).setUp()
        self.add_route()
        self.admin_login()

    def _add_records(self, records):
        """Fill the record inline forms, starting at index 0, with ``records``.

        Returns the last value input that was filled in so the caller can
        submit the form through it. ``records`` must not be empty.
        """
        self.selenium.find_element_by_link_text('Add another Record').click()
        for i, (record_type, value) in enumerate(records):
            type_input = self.selenium.find_element_by_id('id_records-%d-type' % i)
            type_select = Select(type_input)
            type_select.select_by_value(record_type)
            value_input = self.selenium.find_element_by_id('id_records-%d-value' % i)
            value_input.clear()
            value_input.send_keys(value)
        return value_input

    @snapshot_on_error
    def add(self, domain_name, records):
        """Create the domain through the admin add form."""
        add = reverse('admin:domains_domain_add')
        url = self.live_server_url + add
        self.selenium.get(url)

        name = self.selenium.find_element_by_id('id_name')
        name.send_keys(domain_name)

        account_input = self.selenium.find_element_by_id('id_account')
        account_select = Select(account_input)
        account_select.select_by_value(str(self.account.pk))

        value_input = self._add_records(records)
        value_input.submit()
        # Still being on the form URL means the submission was not accepted.
        self.assertNotEqual(url, self.selenium.current_url)

    @snapshot_on_error
    def delete(self, domain_name):
        """Delete the domain through the admin delete confirmation page.

        Raises Domain.DoesNotExist if the domain is not in the database.
        """
        domain = Domain.objects.get(name=domain_name)
        delete = reverse('admin:domains_domain_delete', args=(domain.pk,))
        url = self.live_server_url + delete
        self.selenium.get(url)
        form = self.selenium.find_element_by_name('post')
        form.submit()
        self.assertNotEqual(url, self.selenium.current_url)

    @snapshot_on_error
    def update(self, domain_name, records):
        """Overwrite the domain's records through the admin change form.

        Raises Domain.DoesNotExist if the domain is not in the database.
        """
        domain = Domain.objects.get(name=domain_name)
        change = reverse('admin:domains_domain_change', args=(domain.pk,))
        url = self.live_server_url + change
        self.selenium.get(url)
        value_input = self._add_records(records)
        value_input.submit()
        self.assertNotEqual(url, self.selenium.current_url)


class RESTDomainMixin(DomainTestMixin):
    """Implements the domain operations through the REST client in ``self.rest``."""

    def setUp(self):
        """Run the common set-up, then log into the API and register the routes."""
        super(RESTDomainMixin, self).setUp()
        self.rest_login()
        self.add_route()

    def add(self, domain_name, records):
        """Create the domain, sending each record as a type/value mapping."""
        records = [
            dict(type=record_type, value=value) for record_type, value in records
        ]
        self.rest.domains.create(name=domain_name, records=records)

    def delete(self, domain_name):
        """Delete the domain.

        Raises Domain.DoesNotExist if the domain is not in the database.
        """
        domain = Domain.objects.get(name=domain_name)
        domain = self.rest.domains.retrieve(id=domain.pk)
        domain.delete()

    def update(self, domain_name, records):
        """Update the domain found by name with ``records``."""
        records = [
            dict(type=record_type, value=value) for record_type, value in records
        ]
        domains = self.rest.domains.retrieve(name=domain_name)
        domain = domains.get()
        domain.update(records=records)


class Bind9BackendMixin(object):
    """Provides add_route() for the Bind9 master and slave domain backends."""
    DEPENDENCIES = (
        'orchestra.apps.orchestration',
    )

    def add_route(self):
        """Create the master and slave servers and a route to each backend."""
        master = Server.objects.create(name=self.MASTER_SERVER, address=self.MASTER_SERVER_ADDR)
        backend = backends.Bind9MasterDomainBackend.get_name()
        Route.objects.create(backend=backend, match=True, host=master)
        slave = Server.objects.create(name=self.SLAVE_SERVER, address=self.SLAVE_SERVER_ADDR)
        backend = backends.Bind9SlaveDomainBackend.get_name()
        Route.objects.create(backend=backend, match=True, host=slave)


class RESTBind9BackendDomainTest(Bind9BackendMixin, RESTDomainMixin, BaseLiveServerTestCase):
    """Domain scenarios run through the REST API against the Bind9 backends."""


class AdminBind9BackendDomainTest(Bind9BackendMixin, AdminDomainMixin, BaseLiveServerTestCase):
    """Domain scenarios run through the admin site against the Bind9 backends."""