84 lines
3.0 KiB
Python
84 lines
3.0 KiB
Python
|
"""websocket proxy consumer"""
|
||
|
import threading
|
||
|
from logging import getLogger
|
||
|
from ssl import CERT_NONE
|
||
|
|
||
|
import websocket
|
||
|
from channels.generic.websocket import WebsocketConsumer
|
||
|
|
||
|
from passbook.app_gw.models import ApplicationGatewayProvider
|
||
|
|
||
|
# Logger named after this module, so websocket proxy diagnostics can be
# filtered or configured independently of the rest of the app.
LOGGER = getLogger(name=__name__)
|
||
|
|
||
|
class ProxyConsumer(WebsocketConsumer):
    """Proxy websocket connection to upstream

    The Host header of the incoming (downstream) websocket is used to look up
    an enabled ApplicationGatewayProvider. The connection is then bridged to
    that provider's first upstream through a websocket-client ``WebSocketApp``
    running in a background thread. The downstream connection is accepted
    only from the upstream ``on_open`` callback, and is closed when no
    provider matches.
    """

    # Handshake headers that belong to a single websocket hop and must not be
    # copied from the browser's handshake into the upstream one:
    # websocket-client writes its own Upgrade/Connection/Sec-WebSocket-Key/
    # Sec-WebSocket-Version headers and builds Sec-WebSocket-Protocol from
    # ``subprotocols``; RFC 6455 forbids a repeated Sec-WebSocket-Key, and
    # extensions (e.g. permessage-deflate) are negotiated per hop.
    _hop_headers = frozenset((
        'connection',
        'upgrade',
        'sec-websocket-key',
        'sec-websocket-version',
        'sec-websocket-extensions',
        'sec-websocket-protocol',
    ))

    _headers_dict = {}
    _app_gw = None
    _client = None
    _thread = None
    # Set once the downstream client has disconnected, so the upstream close
    # callback does not try to close the downstream connection again.
    _downstream_closed = False

    def _fix_headers(self, input_dict):
        """Fix headers from bytestrings to normal strings

        :param input_dict: mapping (or iterable of pairs) of header name to
            header value, both UTF-8 encoded bytes.
        :return: new dict with names and values decoded to ``str``.
        """
        return {
            key.decode('utf-8'): value.decode('utf-8')
            for key, value in dict(input_dict).items()
        }

    def connect(self):
        """Extract host header, lookup in database and proxy connection

        Closes the downstream connection when the Host header is missing or
        no enabled ApplicationGatewayProvider matches it. Otherwise starts the
        upstream connection in a background thread; the downstream connection
        is accepted later, from the upstream ``on_open`` callback.
        """
        self._headers_dict = self._fix_headers(dict(self.scope.get('headers')))
        host = self._headers_dict.pop('host', None)
        if host is None:
            LOGGER.warning("Rejecting websocket connection without Host header")
            self.close()
            return
        # Drop per-hop handshake headers; websocket-client writes its own.
        for header in self._hop_headers:
            self._headers_dict.pop(header, None)

        self._app_gw = ApplicationGatewayProvider.objects.filter(
            server_name__contains=[host],
            enabled=True).first()
        if self._app_gw is None:
            LOGGER.debug("No application gateway found for %s, rejecting", host)
            self.close()
            return

        # TODO: prefer an upstream that is already ws:// or wss:// instead of
        # always taking the first configured one.
        upstream = self._app_gw.upstream[0]
        # Rewrite only the scheme (http -> ws, https -> wss); a plain
        # str.replace would also mangle "http" inside the host or path.
        if upstream.startswith('http'):
            upstream = 'ws' + upstream[len('http'):]
        upstream += self.scope.get('path')
        query_string = self.scope.get('query_string', b'').decode('utf-8')
        if query_string:
            upstream += '?' + query_string

        sslopt = {}
        if not self._app_gw.upstream_ssl_verification:
            sslopt = {"cert_reqs": CERT_NONE}
        self._client = websocket.WebSocketApp(
            url=upstream,
            subprotocols=self.scope.get('subprotocols'),
            header=self._headers_dict,
            on_message=self._client_on_message_handler(),
            on_error=self._client_on_error_handler(),
            on_close=self._client_on_close_handler(),
            on_open=self._client_on_open_handler())
        LOGGER.debug("Accepting connection for %s", host)
        # run_forever blocks for the lifetime of the upstream connection, so
        # it runs in its own thread; the client is bound here, not looked up
        # lazily through self.
        self._thread = threading.Thread(
            target=self._client.run_forever,
            kwargs={'sslopt': sslopt})
        self._thread.start()

    def _client_on_open_handler(self):
        """Return the upstream ``on_open`` callback.

        It accepts the downstream connection, passing along the subprotocol
        the upstream server selected in its handshake response.
        """
        return lambda ws: self.accept(self._client.sock.handshake_response.subprotocol)

    def _client_on_message_handler(self):
        """Return the upstream ``on_message`` callback.

        ``str`` messages are relayed downstream as text frames, anything else
        as binary frames.
        """
        # pylint: disable=unused-argument,invalid-name
        def message_handler(ws, message):
            if isinstance(message, str):
                self.send(text_data=message)
            else:
                self.send(bytes_data=message)
        return message_handler

    def _client_on_error_handler(self):
        """Return the upstream ``on_error`` callback, which logs the error."""
        # pylint: disable=unused-argument,invalid-name
        def error_handler(ws, error):
            LOGGER.warning("Upstream websocket error: %s", error)
        return error_handler

    def _client_on_close_handler(self):
        """Return the upstream ``on_close`` callback.

        When the upstream connection ends, the downstream connection is closed
        as well, unless the downstream client already disconnected.
        """
        # pylint: disable=unused-argument,invalid-name
        def close_handler(ws, *args):
            # *args absorbs whatever the websocket-client version passes
            # (newer releases add the close status code and message).
            if not self._downstream_closed:
                self.close()
        return close_handler

    def disconnect(self, code):
        """Close the upstream connection once the downstream client is gone.

        :param code: close code of the downstream connection (unused).
        """
        self._downstream_closed = True
        # connect() may have rejected the request before creating a client.
        if self._client is not None:
            self._client.close()

    def receive(self, text_data=None, bytes_data=None):
        """Relay a downstream frame upstream, keeping its text/binary type.

        Tests for ``None`` rather than truthiness so that empty frames are
        relayed instead of leaving the opcode undetermined.
        """
        if text_data is not None:
            self._client.send(text_data, websocket.ABNF.OPCODE_TEXT)
        else:
            self._client.send(bytes_data, websocket.ABNF.OPCODE_BINARY)
|