django-orchestra-test/orchestra/contrib/orchestration/methods.py
2015-05-11 14:05:39 +00:00

176 lines
6.2 KiB
Python

import codecs
import hashlib
import json
import logging
import os
import select
import socket
import sys

from celery.datastructures import ExceptionInfo
from django.conf import settings as djsettings

from orchestra.utils.python import CaptureStdout, import_class
from orchestra.utils.sys import sshrun

from . import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Cache of connected paramiko SSHClient objects keyed by server address;
# Paramiko() looks connections up here and stores newly opened ones.
paramiko_connections = {}
def Paramiko(backend, log, server, cmds, async=False):
"""
Executes cmds to remote server using Pramaiko
"""
import paramiko
script = '\n'.join(cmds)
script = script.replace('\r', '')
log.state = log.STARTED
log.script = script
log.save(update_fields=('script', 'state', 'updated_at'))
if not cmds:
return
channel = None
ssh = None
try:
addr = server.get_address()
# ssh connection
ssh = paramiko_connections.get(addr)
if not ssh:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
key = settings.ORCHESTRATION_SSH_KEY_PATH
try:
ssh.connect(addr, username='root', key_filename=key)
except socket.error as e:
logger.error('%s timed out on %s' % (backend, addr))
log.state = log.TIMEOUT
log.stderr = str(e)
log.save(update_fields=('state', 'stderr', 'updated_at'))
return
paramiko_connections[addr] = ssh
transport = ssh.get_transport()
channel = transport.open_session()
channel.exec_command(backend.script_executable)
channel.sendall(script)
channel.shutdown_write()
# Log results
logger.debug('%s running on %s' % (backend, server))
if async:
second = False
while True:
# Non-blocking is the secret ingridient in the async sauce
select.select([channel], [], [])
if channel.recv_ready():
part = channel.recv(1024).decode('utf-8')
while part:
log.stdout += part
part = channel.recv(1024).decode('utf-8')
if channel.recv_stderr_ready():
part = channel.recv_stderr(1024).decode('utf-8')
while part:
log.stderr += part
part = channel.recv_stderr(1024).decode('utf-8')
log.save(update_fields=('stdout', 'stderr', 'updated_at'))
if channel.exit_status_ready():
if second:
break
second = True
else:
log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
log.exit_code = channel.recv_exit_status()
log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
log.save()
except:
log.state = log.ERROR
log.traceback = ExceptionInfo(sys.exc_info()).traceback
logger.error('Exception while executing %s on %s' % (backend, server))
logger.debug(log.traceback)
log.save()
finally:
if log.state == log.STARTED:
log.state = log.ABORTED
log.save(update_fields=('state', 'updated_at'))
if channel is not None:
channel.close()
def OpenSSH(backend, log, server, cmds, async=False):
"""
Executes cmds to remote server using SSH with connection resuse for maximum performance
"""
script = '\n'.join(cmds)
script = script.replace('\r', '')
log.state = log.STARTED
log.script = script
log.save(update_fields=('script', 'state', 'updated_at'))
if not cmds:
return
channel = None
ssh = None
try:
ssh = sshrun(server.get_address(), script, executable=backend.script_executable,
persist=True, async=async, silent=True)
logger.debug('%s running on %s' % (backend, server))
if async:
second = False
for state in ssh:
log.stdout += state.stdout.decode('utf8')
log.stderr += state.stderr.decode('utf8')
log.save(update_fields=('stdout', 'stderr', 'updated_at'))
log.exit_code = state.exit_code
else:
log.stdout = ssh.stdout.decode('utf8')
log.stderr = ssh.stderr.decode('utf8')
log.exit_code = ssh.exit_code
log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
log.save()
except:
log.state = log.ERROR
log.traceback = ExceptionInfo(sys.exc_info()).traceback
logger.error('Exception while executing %s on %s' % (backend, server))
logger.debug(log.traceback)
log.save()
finally:
if log.state == log.STARTED:
log.state = log.ABORTED
log.save(update_fields=('state', 'updated_at'))
def SSH(*args, **kwargs):
    """
    Facade over the available SSH methods.

    Imports the callable named by ``settings.ORCHESTRATION_SSH_METHOD_BACKEND``
    and forwards every positional and keyword argument to it, returning
    whatever that callable returns.
    """
    ssh_method = import_class(settings.ORCHESTRATION_SSH_METHOD_BACKEND)
    result = ssh_method(*args, **kwargs)
    return result
def Python(backend, log, server, cmds, async=False):
script = [ str(cmd.func.__name__) + str(cmd.args) for cmd in cmds ]
script = json.dumps(script, indent=4).replace('"', '')
log.state = log.STARTED
log.script = '\n'.join((log.script, script))
log.save(update_fields=('script', 'state', 'updated_at'))
try:
for cmd in cmds:
with CaptureStdout() as stdout:
result = cmd(server)
for line in stdout:
log.stdout += line + '\n'
if async:
log.save(update_fields=('stdout', 'updated_at'))
except:
log.exit_code = 1
log.state = log.FAILURE
log.traceback = ExceptionInfo(sys.exc_info()).traceback
logger.error('Exception while executing %s on %s' % (backend, server))
else:
log.exit_code = 0
log.state = log.SUCCESS
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
log.save()