fucked up tests

This commit is contained in:
Marc 2014-10-04 09:29:18 +00:00
parent 6a3e3f637c
commit 55f4b6e88c
11 changed files with 97 additions and 70 deletions

View file

@ -43,10 +43,11 @@ class Bind9MasterDomainBackend(ServiceController):
self.delete_conf(context) self.delete_conf(context)
def delete_conf(self, context): def delete_conf(self, context):
self.append('awk -v s=%(name)s \'BEGIN {' self.append(textwrap.dedent("""
' RS=""; s="zone \\""s"\\""' awk -v s=%(name)s 'BEGIN {
'} $0!~s{ print $0"\\n" }\' %(conf_path)s > %(conf_path)s.tmp' RS=""; s="zone \""s"\""
% context) } $0!~s{ print $0"\n" }' %(conf_path)s > %(conf_path)s.tmp""" % context
))
self.append('diff -I"^\s*//" %(conf_path)s.tmp %(conf_path)s || UPDATED=1' % context) self.append('diff -I"^\s*//" %(conf_path)s.tmp %(conf_path)s || UPDATED=1' % context)
self.append('mv %(conf_path)s.tmp %(conf_path)s' % context) self.append('mv %(conf_path)s.tmp %(conf_path)s' % context)
@ -62,13 +63,16 @@ class Bind9MasterDomainBackend(ServiceController):
servers.append(server.get_ip()) servers.append(server.get_ip())
return servers return servers
def get_slaves(self, domain):
return self.get_servers(domain, Bind9SlaveDomainBackend)
def get_context(self, domain): def get_context(self, domain):
context = { context = {
'name': domain.name, 'name': domain.name,
'zone_path': settings.DOMAINS_ZONE_PATH % {'name': domain.name}, 'zone_path': settings.DOMAINS_ZONE_PATH % {'name': domain.name},
'subdomains': domain.subdomains.all(), 'subdomains': domain.subdomains.all(),
'banner': self.get_banner(), 'banner': self.get_banner(),
'slaves': '; '.join(self.get_servers(domain, Bind9SlaveDomainBackend)), 'slaves': '; '.join(self.get_slaves(domain)) or 'none',
} }
context.update({ context.update({
'conf_path': settings.DOMAINS_MASTERS_PATH, 'conf_path': settings.DOMAINS_MASTERS_PATH,
@ -101,12 +105,15 @@ class Bind9SlaveDomainBackend(Bind9MasterDomainBackend):
""" ideally slave should be restarted after master """ """ ideally slave should be restarted after master """
self.append('[[ $UPDATED == 1 ]] && { sleep 1 && service bind9 reload; } &') self.append('[[ $UPDATED == 1 ]] && { sleep 1 && service bind9 reload; } &')
def get_masters(self, domain):
return self.get_servers(domain, Bind9MasterDomainBackend)
def get_context(self, domain): def get_context(self, domain):
context = { context = {
'name': domain.name, 'name': domain.name,
'banner': self.get_banner(), 'banner': self.get_banner(),
'subdomains': domain.subdomains.all(), 'subdomains': domain.subdomains.all(),
'masters': '; '.join(self.get_servers(domain, Bind9MasterDomainBackend)), 'masters': '; '.join(self.get_masters(domain)) or 'none',
} }
context.update({ context.update({
'conf_path': settings.DOMAINS_SLAVES_PATH, 'conf_path': settings.DOMAINS_SLAVES_PATH,

View file

@ -1,33 +1,38 @@
import copy import copy
from functools import partial
from .models import Domain, Record from .models import Domain, Record
def domain_for_validation(instance, records): def domain_for_validation(instance, records):
""" Create a fake zone in order to generate the whole zone file and check it """ """
Since the new data is not yet on the database, we update it on the fly,
so when validation calls render_zone() it will use the new provided data
"""
domain = copy.copy(instance) domain = copy.copy(instance)
if not domain.pk:
domain.top = domain.get_top()
def get_records(): def get_records():
for data in records: for data in records:
yield Record(type=data['type'], value=data['value']) yield Record(type=data['type'], value=data['value'])
domain.get_records = get_records domain.get_records = get_records
def get_top_subdomains(exclude=None): def get_subdomains(replace=None, make_top=False):
subdomains = [] for subdomain in Domain.objects.filter(name__endswith='.%s' % domain.name):
for subdomain in Domain.objects.filter(name__endswith='.%s' % domain.origin.name): if replace == subdomain.pk:
if exclude != subdomain.pk: # domain is a subdomain, yield our copy
subdomain.top = domain yield domain
else:
if make_top:
subdomain.top = domain
yield subdomain yield subdomain
domain.get_top_subdomains = get_top_subdomains
if not domain.pk:
# top domain lookup for new domains
domain.top = domain.get_top()
if domain.top: if domain.top:
subdomains = domain.get_top_subdomains(exclude=instance.pk) # is a subdomains
domain.top.get_subdomains = lambda: list(subdomains) + [domain] domain.top.get_subdomains = partial(get_subdomains, replace=domain.pk)
elif not domain.pk: elif not domain.pk:
subdomains = [] # is top domain
for subdomain in Domain.objects.filter(name__endswith=domain.name): domain.get_subdomains = partial(get_subdomains, make_top=True)
subdomain.top = domain
subdomains.append(subdomain)
domain.get_subdomains = get_top_subdomains
return domain return domain

View file

@ -24,30 +24,35 @@ class Domain(models.Model):
@property @property
def origin(self): def origin(self):
# Do not cache
return self.top or self return self.top or self
@property @property
def is_top(self): def is_top(self):
# Do not cache # don't cache, don't replace by top_id
return not bool(self.top) return not bool(self.top)
def get_records(self): def get_records(self):
""" proxy method, needed for input validation, see helpers.domain_for_validation """ """ proxy method, needed for input validation, see helpers.domain_for_validation """
return self.records.all() return self.records.all()
def get_top_subdomains(self): def get_subdomains(self):
""" proxy method, needed for input validation, see helpers.domain_for_validation """ """ proxy method, needed for input validation, see helpers.domain_for_validation """
return self.origin.subdomains.all() return self.origin.subdomains.all()
def get_subdomains(self): def get_top(self):
""" proxy method, needed for input validation, see helpers.domain_for_validation """ split = self.name.split('.')
return self.get_top_subdomains().filter(name__endswith=r'.%s' % self.name) top = None
for i in range(1, len(split)-1):
name = '.'.join(split[i:])
domain = Domain.objects.filter(name=name)
if domain:
top = domain.get()
return top
def render_zone(self): def render_zone(self):
origin = self.origin origin = self.origin
zone = origin.render_records() zone = origin.render_records()
for subdomain in origin.get_top_subdomains(): for subdomain in origin.get_subdomains():
zone += subdomain.render_records() zone += subdomain.render_records()
return zone return zone
@ -135,16 +140,6 @@ class Domain(models.Model):
domain.save(update_fields=['top']) domain.save(update_fields=['top'])
self.subdomains.update(account_id=self.account_id) self.subdomains.update(account_id=self.account_id)
def get_top(self):
split = self.name.split('.')
top = None
for i in range(1, len(split)-1):
name = '.'.join(split[i:])
domain = Domain.objects.filter(name=name)
if domain:
top = domain.get()
return top
class Record(models.Model): class Record(models.Model):
""" Represents a domain resource record """ """ Represents a domain resource record """

View file

@ -71,7 +71,7 @@ class ServiceBackend(plugins.Plugin):
def get_banner(self): def get_banner(self):
time = timezone.now().strftime("%h %d, %Y %I:%M:%S") time = timezone.now().strftime("%h %d, %Y %I:%M:%S")
return "Generated by Orchestra %s" % time return "Generated by Orchestra at %s" % time
def execute(self, server): def execute(self, server):
from .models import BackendLog from .models import BackendLog

View file

@ -37,9 +37,9 @@ def close_connection(execute):
def execute(operations): def execute(operations):
""" generates and executes the operations on the servers """ """ generates and executes the operations on the servers """
router = import_class(settings.ORCHESTRATION_ROUTER) router = import_class(settings.ORCHESTRATION_ROUTER)
# Generate scripts per server+backend
scripts = {} scripts = {}
cache = {} cache = {}
# Generate scripts per server+backend
for operation in operations: for operation in operations:
logger.debug("Queued %s" % str(operation)) logger.debug("Queued %s" % str(operation))
servers = router.get_servers(operation, cache=cache) servers = router.get_servers(operation, cache=cache)
@ -50,6 +50,7 @@ def execute(operations):
scripts[key][0].prepare() scripts[key][0].prepare()
else: else:
scripts[key][1].append(operation) scripts[key][1].append(operation)
# Get and call backend action method
method = getattr(scripts[key][0], operation.action) method = getattr(scripts[key][0], operation.action)
method(operation.instance) method(operation.instance)
# Execute scripts on each server # Execute scripts on each server
@ -67,12 +68,15 @@ def execute(operations):
executions.append((execute, operations)) executions.append((execute, operations))
[ thread.join() for thread in threads ] [ thread.join() for thread in threads ]
logs = [] logs = []
# collect results
for execution, operations in executions: for execution, operations in executions:
for operation in operations: for operation in operations:
logger.info("Executed %s" % str(operation)) logger.info("Executed %s" % str(operation))
operation.log = execution.log operation.log = execution.log
operation.save() operation.save()
logger.debug(execution.log.stdout) stdout = execution.log.stdout.strip()
logger.debug(execution.log.stderr) stdout and logger.debug('STDOUT %s', stdout)
stderr = execution.log.stderr.strip()
stderr and logger.debug('STDERR %s', stderr)
logs.append(execution.log) logs.append(execution.log)
return logs return logs

View file

@ -1,5 +1,6 @@
import hashlib import hashlib
import json import json
import logging
import os import os
import socket import socket
import sys import sys
@ -11,13 +12,16 @@ from celery.datastructures import ExceptionInfo
from . import settings from . import settings
logger = logging.getLogger(__name__)
def BashSSH(backend, log, server, cmds): def BashSSH(backend, log, server, cmds):
from .models import BackendLog from .models import BackendLog
script = '\n'.join(['set -e', 'set -o pipefail'] + cmds + ['exit 0']) script = '\n'.join(['set -e', 'set -o pipefail'] + cmds + ['exit 0'])
script = script.replace('\r', '') script = script.replace('\r', '')
log.script = script log.script = script
log.save(update_fields=['script']) log.save(update_fields=['script'])
logger.debug('%s is going to be executed on %s' % (backend, server))
try: try:
# Avoid "Argument list too long" on large scripts by genereting a file # Avoid "Argument list too long" on large scripts by genereting a file
# and scping it to the remote server # and scping it to the remote server
@ -30,15 +34,16 @@ def BashSSH(backend, log, server, cmds):
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
addr = server.get_address() addr = server.get_address()
try: try:
ssh.connect(addr, username='root', ssh.connect(addr, username='root', key_filename=settings.ORCHESTRATION_SSH_KEY_PATH)
key_filename=settings.ORCHESTRATION_SSH_KEY_PATH)
except socket.error: except socket.error:
logger.error('%s timed out on %s' % (backend, server))
log.state = BackendLog.TIMEOUT log.state = BackendLog.TIMEOUT
log.save(update_fields=['state']) log.save(update_fields=['state'])
return return
transport = ssh.get_transport() transport = ssh.get_transport()
sftp = paramiko.SFTPClient.from_transport(transport) sftp = paramiko.SFTPClient.from_transport(transport)
sftp.put(path, "%s.remote" % path) sftp.put(path, "%s.remote" % path)
logger.debug('%s copied on %s' % (backend, server))
sftp.close() sftp.close()
os.remove(path) os.remove(path)
@ -55,6 +60,7 @@ def BashSSH(backend, log, server, cmds):
channel = transport.open_session() channel = transport.open_session()
channel.exec_command(cmd) channel.exec_command(cmd)
logger.debug('%s running on %s' % (backend, server))
if True: # TODO if not async if True: # TODO if not async
log.stdout += channel.makefile('rb', -1).read().decode('utf-8') log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8') log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
@ -71,10 +77,12 @@ def BashSSH(backend, log, server, cmds):
break break
log.exit_code = exit_code = channel.recv_exit_status() log.exit_code = exit_code = channel.recv_exit_status()
log.state = BackendLog.SUCCESS if exit_code == 0 else BackendLog.FAILURE log.state = BackendLog.SUCCESS if exit_code == 0 else BackendLog.FAILURE
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
channel.close() channel.close()
ssh.close() ssh.close()
log.save() log.save()
except: except:
logger.error('Exception while executing %s on %s' % (backend, server))
log.state = BackendLog.ERROR log.state = BackendLog.ERROR
log.traceback = ExceptionInfo(sys.exc_info()).traceback log.traceback = ExceptionInfo(sys.exc_info()).traceback
log.save() log.save()

View file

@ -176,7 +176,10 @@ class Route(models.Model):
for route in cls.objects.filter(is_active=True, backend=backend.get_name()): for route in cls.objects.filter(is_active=True, backend=backend.get_name()):
for action in backend.get_actions(): for action in backend.get_actions():
_key = (route.backend, action) _key = (route.backend, action)
cache[_key] = [route] try:
cache[_key].append(route)
except KeyError:
cache[_key] = [route]
routes = cache[key] routes = cache[key]
for route in routes: for route in routes:
if route.matches(operation.instance): if route.matches(operation.instance):
@ -185,7 +188,9 @@ class Route(models.Model):
def matches(self, instance): def matches(self, instance):
safe_locals = { safe_locals = {
'instance': instance 'instance': instance,
'obj': instance,
instance._meta.model_name: instance,
} }
return eval(self.match, safe_locals) return eval(self.match, safe_locals)

View file

@ -1,8 +1,7 @@
from orchestra.utils.tests import BaseTestCase from orchestra.utils.tests import BaseTestCase
from .. import operations, backends from .. import backends
from ..models import Route, Server from ..models import Route, Server, BackendOperation as Operation
from ..utils import get_backend_choices
class RouterTests(BaseTestCase): class RouterTests(BaseTestCase):
@ -18,25 +17,25 @@ class RouterTests(BaseTestCase):
def test_get_instances(self): def test_get_instances(self):
class TestBackend(backends.ServiceBackend): class TestBackend(backends.ServiceController):
verbose_name = 'Route' verbose_name = 'Route'
models = ['routes.Route',] models = ['routes.Route']
choices = get_backend_choices(backends.ServiceBackend.get_backends()) def save(self, instance):
pass
choices = backends.ServiceBackend.get_plugin_choices()
Route._meta.get_field_by_name('backend')[0]._choices = choices Route._meta.get_field_by_name('backend')[0]._choices = choices
backend = TestBackend.get_name() backend = TestBackend.get_name()
route = Route.objects.create(backend=backend, host=self.host, route = Route.objects.create(backend=backend, host=self.host, match='True')
match='True') operation = Operation(backend=TestBackend, instance=route, action='save')
operation = operations.Operation(TestBackend, route, 'commit')
self.assertEqual(1, len(Route.get_servers(operation))) self.assertEqual(1, len(Route.get_servers(operation)))
route = Route.objects.create(backend=backend, host=self.host1, route = Route.objects.create(backend=backend, host=self.host1,
match='instance.backend == "TestBackend"') match='route.backend == "%s"' % TestBackend.get_name())
operation = operations.Operation(TestBackend, route, 'commit')
self.assertEqual(2, len(Route.get_servers(operation))) self.assertEqual(2, len(Route.get_servers(operation)))
route = Route.objects.create(backend=backend, host=self.host2, route = Route.objects.create(backend=backend, host=self.host2,
match='instance.backend == "something else"') match='route.backend == "something else"')
operation = operations.Operation(TestBackend, route, 'commit')
self.assertEqual(2, len(Route.get_servers(operation))) self.assertEqual(2, len(Route.get_servers(operation)))

View file

@ -46,7 +46,9 @@ class ServiceHandler(plugins.Plugin):
def matches(self, instance): def matches(self, instance):
safe_locals = { safe_locals = {
instance._meta.model_name: instance 'instance': instance,
'obj': instance,
instance._meta.model_name: instance,
} }
return eval(self.match, safe_locals) return eval(self.match, safe_locals)

View file

@ -90,6 +90,7 @@ class BaseLiveServerTestCase(AppDependencyMixin, LiveServerTestCase):
def setUp(self): def setUp(self):
super(BaseLiveServerTestCase, self).setUp() super(BaseLiveServerTestCase, self).setUp()
self.rest = Api(self.live_server_url + '/api/') self.rest = Api(self.live_server_url + '/api/')
self.rest.enable_logging()
self.account = self.create_account(superuser=True) self.account = self.create_account(superuser=True)
def admin_login(self): def admin_login(self):

View file

@ -1,5 +1,3 @@
apt-get install postfix
# http://www.postfix.org/VIRTUAL_README.html#virtual_mailbox # http://www.postfix.org/VIRTUAL_README.html#virtual_mailbox
# https://help.ubuntu.com/community/PostfixVirtualMailBoxClamSmtpHowto # https://help.ubuntu.com/community/PostfixVirtualMailBoxClamSmtpHowto
@ -9,15 +7,17 @@ apt-get install postfix
apt-get install dovecot-core dovecot-imapd dovecot-pop3d dovecot-lmtpd dovecot-sieve apt-get install dovecot-core dovecot-imapd dovecot-pop3d dovecot-lmtpd dovecot-sieve
sed -i "s#^mail_location = mbox.*#mail_location = maildir:~/Maildir#" /etc/dovecot/conf.d/10-mail.conf
echo 'auth_username_format = %n' >> /etc/dovecot/conf.d/10-auth.conf echo 'mail_location = maildir:~/Maildir
echo 'service lmtp { mail_plugins = quota
auth_username_format = %n
service lmtp {
unix_listener /var/spool/postfix/private/dovecot-lmtp { unix_listener /var/spool/postfix/private/dovecot-lmtp {
group = postfix group = postfix
mode = 0600 mode = 0600
user = postfix user = postfix
} }
}' >> /etc/dovecot/conf.d/10-master.conf }' > /etc/dovecot/local.conf
cat > /etc/apt/sources.list.d/mailscanner.list << 'EOF' cat > /etc/apt/sources.list.d/mailscanner.list << 'EOF'
@ -38,6 +38,7 @@ echo 'mailbox_transport = lmtp:unix:private/dovecot-lmtp' >> /etc/postfix/main.c
/etc/init.d/dovecot restart /etc/init.d/dovecot restart
/etc/init.d/postfix restart /etc/init.d/postfix restart