import codecs
import inspect
import logging
import select
import socket
import sys
import textwrap

from celery.datastructures import ExceptionInfo

from orchestra.settings import ORCHESTRA_SSH_DEFAULT_USER
from orchestra.utils.sys import sshrun
from orchestra.utils.python import CaptureStdout, import_class

from . import settings


logger = logging.getLogger(__name__)


def Paramiko(backend, log, server, cmds, async=False, paramiko_connections={}):
    """
    Executes cmds to remote server using Pramaiko
    """
    import paramiko
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = script
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    channel = None
    ssh = None
    try:
        addr = server.get_address()
        # ssh connection
        ssh = paramiko_connections.get(addr)
        if not ssh:
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            key = settings.ORCHESTRATION_SSH_KEY_PATH
            try:
                ssh.connect(addr, username=ORCHESTRA_SSH_DEFAULT_USER, key_filename=key)
            except socket.error as e:
                logger.error('%s timed out on %s' % (backend, addr))
                log.state = log.TIMEOUT
                log.stderr = str(e)
                log.save(update_fields=('state', 'stderr', 'updated_at'))
                return
            paramiko_connections[addr] = ssh
        transport = ssh.get_transport()
        channel = transport.open_session()
        channel.exec_command(backend.script_executable)
        channel.sendall(script)
        channel.shutdown_write()
        # Log results
        logger.debug('%s running on %s' % (backend, server))
        if async:
            second = False
            while True:
                # Non-blocking is the secret ingridient in the async sauce
                select.select([channel], [], [])
                if channel.recv_ready():
                    part = channel.recv(1024).decode('utf-8')
                    while part:
                        log.stdout += part
                        part = channel.recv(1024).decode('utf-8')
                if channel.recv_stderr_ready():
                    part = channel.recv_stderr(1024).decode('utf-8')
                    while part:
                        log.stderr += part
                        part = channel.recv_stderr(1024).decode('utf-8')
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
                if channel.exit_status_ready():
                    if second:
                        break
                    second = True
        else:
            log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
            log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
        
        log.exit_code = channel.recv_exit_status()
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except:
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))
        if channel is not None:
            channel.close()


def OpenSSH(backend, log, server, cmds, async=False):
    """
    Executes cmds to remote server using SSH with connection resuse for maximum performance
    """
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = '\n'.join((log.script, script))
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    try:
        ssh = sshrun(server.get_address(), script, executable=backend.script_executable,
            persist=True, async=async, silent=True)
        logger.debug('%s running on %s' % (backend, server))
        if async:
            for state in ssh:
                log.stdout += state.stdout.decode('utf8')
                log.stderr += state.stderr.decode('utf8')
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
            exit_code = state.exit_code
        else:
            log.stdout += ssh.stdout.decode('utf8')
            log.stderr += ssh.stderr.decode('utf8')
            exit_code = ssh.exit_code
        if not log.exit_code:
            log.exit_code = exit_code
            if exit_code == 255 and log.stderr.startswith('ssh: connect to host'):
                log.state = log.TIMEOUT
            else:
                log.state = log.SUCCESS if exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except:
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))


def SSH(*args, **kwargs):
    """
    Facade function that lets the SSH backend be chosen through settings

    The callable named by ``settings.ORCHESTRATION_SSH_METHOD_BACKEND`` is
    imported and invoked with exactly the arguments received here; its return
    value is returned unchanged.
    """
    backend_path = settings.ORCHESTRATION_SSH_METHOD_BACKEND
    execute = import_class(backend_path)
    return execute(*args, **kwargs)


def Python(backend, log, server, cmds, async=False):
    script = ''
    functions = set()
    for cmd in cmds:
        if cmd.func not in functions:
            functions.add(cmd.func)
            script += textwrap.dedent(''.join(inspect.getsourcelines(cmd.func)[0]))
    script += '\n'
    for cmd in cmds:
        script += '# %s %s\n' % (cmd.func.__name__, cmd.args)
    log.state = log.STARTED
    log.script = '\n'.join((log.script, script))
    log.save(update_fields=('script', 'state', 'updated_at'))
    stdout = ''
    try:
        for cmd in cmds:
            with CaptureStdout() as stdout:
                result = cmd(server)
            for line in stdout:
                log.stdout += line + '\n'
            if result:
                log.stdout += '# Result: %s\n' % result
            if async:
                log.save(update_fields=('stdout', 'updated_at'))
    except:
        log.exit_code = 1
        log.state = log.FAILURE
        log.stdout += '\n'.join(stdout)
        log.traceback += ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
    else:
        if not log.exit_code:
            log.exit_code = 0
            log.state = log.SUCCESS
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
    log.save()