django-orchestra/orchestra/contrib/orchestration/methods.py

160 lines
5.6 KiB
Python

import hashlib
import json
import logging
import os
import socket
import sys
import select
import paramiko
from celery.datastructures import ExceptionInfo
from django.conf import settings as djsettings
from orchestra.utils.python import CaptureStdout
from . import settings
logger = logging.getLogger(__name__)
transports = {}
def SSH(backend, log, server, cmds, async=False):
"""
Executes cmds to remote server using SSH
The script is first copied using SCP in order to overflood the channel with large scripts
Then the script is executed using the defined backend.script_executable
"""
script = '\n'.join(cmds)
script = script.replace('\r', '')
bscript = script.encode('utf-8')
digest = hashlib.md5(bscript).hexdigest()
path = os.path.join(settings.ORCHESTRATION_TEMP_SCRIPT_DIR, digest)
remote_path = "%s.remote" % path
# Ensure unique local paths for each file because of problems when os.remove(path)
path += '@%s' % str(server)
log.state = log.STARTED
log.script = '# %s\n%s' % (remote_path, script)
log.save(update_fields=('script', 'state'))
if not cmds:
return
channel = None
ssh = None
try:
# Avoid "Argument list too long" on large scripts by genereting a file
# and scping it to the remote server
with os.fdopen(os.open(path, os.O_WRONLY | os.O_CREAT, 0o600), 'wb') as handle:
handle.write(bscript)
# ssh connection
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
addr = server.get_address()
key = settings.ORCHESTRATION_SSH_KEY_PATH
try:
ssh.connect(addr, username='root', key_filename=key, timeout=10)
except socket.error as e:
logger.error('%s timed out on %s' % (backend, addr))
log.state = log.TIMEOUT
log.stderr = str(e)
log.save(update_fields=['state', 'stderr'])
return
transport = ssh.get_transport()
# Copy script to remote server
sftp = paramiko.SFTPClient.from_transport(transport)
sftp.put(path, remote_path)
sftp.chmod(remote_path, 0o600)
sftp.close()
os.remove(path)
# Execute it
context = {
'executable': backend.script_executable,
'remote_path': remote_path,
'digest': digest,
'remove': '' if djsettings.DEBUG else "rm -fr %(remote_path)s\n",
}
cmd = (
"[[ $(md5sum %(remote_path)s|awk {'print $1'}) == %(digest)s ]] && %(executable)s %(remote_path)s\n"
"RETURN_CODE=$?\n"
"%(remove)s"
"exit $RETURN_CODE" % context
)
channel = transport.open_session()
channel.exec_command(cmd)
# Log results
logger.debug('%s running on %s' % (backend, server))
if async:
second = False
while True:
# Non-blocking is the secret ingridient in the async sauce
select.select([channel], [], [])
if channel.recv_ready():
part = channel.recv(1024).decode('utf-8')
while part:
log.stdout += part
part = channel.recv(1024).decode('utf-8')
if channel.recv_stderr_ready():
part = channel.recv_stderr(1024).decode('utf-8')
while part:
log.stderr += part
part = channel.recv_stderr(1024).decode('utf-8')
log.save(update_fields=['stdout', 'stderr'])
if channel.exit_status_ready():
if second:
break
second = True
else:
log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
log.exit_code = exit_code = channel.recv_exit_status()
log.state = log.SUCCESS if exit_code == 0 else log.FAILURE
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
log.save()
except:
log.state = log.ERROR
log.traceback = ExceptionInfo(sys.exc_info()).traceback
logger.error('Exception while executing %s on %s' % (backend, server))
logger.debug(log.traceback)
log.save()
finally:
if log.state == log.STARTED:
log.state = log.ABORTED
log.save(update_fields=['state'])
if channel is not None:
channel.close()
if ssh is not None:
ssh.close()
def Python(backend, log, server, cmds, async=False):
# TODO collect stdout?
script = [ str(cmd.func.__name__) + str(cmd.args) for cmd in cmds ]
script = json.dumps(script, indent=4).replace('"', '')
log.state = log.STARTED
log.script = '\n'.join([log.script, script])
log.save(update_fields=('script', 'state'))
try:
for cmd in cmds:
with CaptureStdout() as stdout:
result = cmd(server)
for line in stdout:
log.stdout += line + '\n'
if async:
log.save(update_fields=['stdout'])
except:
log.exit_code = 1
log.state = log.FAILURE
log.traceback = ExceptionInfo(sys.exc_info()).traceback
logger.error('Exception while executing %s on %s' % (backend, server))
else:
log.exit_code = 0
log.state = log.SUCCESS
logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
log.save()