diff --git a/passbook/app_gw/requirements.txt b/passbook/app_gw/requirements.txt index ae3eaf219..934e9cc25 100644 --- a/passbook/app_gw/requirements.txt +++ b/passbook/app_gw/requirements.txt @@ -1,2 +1,5 @@ django-revproxy urllib3[secure] +channels +service_identity +websocket-client diff --git a/passbook/app_gw/settings.py b/passbook/app_gw/settings.py index 6e5808d8d..2fabd10ef 100644 --- a/passbook/app_gw/settings.py +++ b/passbook/app_gw/settings.py @@ -1,5 +1,5 @@ """Application Security Gateway settings""" - -# INSTALLED_APPS = [ -# 'revproxy' -# ] +INSTALLED_APPS = [ + 'channels' +] +ASGI_APPLICATION = "passbook.app_gw.websocket.routing.application" diff --git a/passbook/app_gw/websocket/__init__.py b/passbook/app_gw/websocket/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/app_gw/websocket/consumer.py b/passbook/app_gw/websocket/consumer.py new file mode 100644 index 000000000..bedfa41ad --- /dev/null +++ b/passbook/app_gw/websocket/consumer.py @@ -0,0 +1,83 @@ +"""websocket proxy consumer""" +import threading +from logging import getLogger +from ssl import CERT_NONE + +import websocket +from channels.generic.websocket import WebsocketConsumer + +from passbook.app_gw.models import ApplicationGatewayProvider + +LOGGER = getLogger(__name__) + +class ProxyConsumer(WebsocketConsumer): + """Proxy websocket connection to upstream""" + + _headers_dict = {} + _app_gw = None + _client = None + _thread = None + + def _fix_headers(self, input_dict): + """Fix headers from bytestrings to normal strings""" + return { + key.decode('utf-8'): value.decode('utf-8') + for key, value in dict(input_dict).items() + } + + def connect(self): + """Extract host header, lookup in database and proxy connection""" + self._headers_dict = self._fix_headers(dict(self.scope.get('headers'))) + host = self._headers_dict.pop('host') + query_string = self.scope.get('query_string').decode('utf-8') + matches = ApplicationGatewayProvider.objects.filter( + server_name__contains=[host], + enabled=True) + if matches.exists(): + self._app_gw = matches.first() + # TODO: Get upstream that starts with wss or + upstream = self._app_gw.upstream[0].replace('http', 'ws') + self.scope.get('path') + if query_string: + upstream += '?' + query_string + sslopt = {} + if not self._app_gw.upstream_ssl_verification: + sslopt = {"cert_reqs": CERT_NONE} + self._client = websocket.WebSocketApp( + url=upstream, + subprotocols=self.scope.get('subprotocols'), + header=self._headers_dict, + on_message=self._client_on_message_handler(), + on_error=self._client_on_error_handler(), + on_close=self._client_on_close_handler(), + on_open=self._client_on_open_handler()) + LOGGER.debug("Accepting connection for %s", host) + self._thread = threading.Thread(target=lambda: self._client.run_forever(sslopt=sslopt)) + self._thread.start() + + def _client_on_open_handler(self): + return lambda ws: self.accept(self._client.sock.handshake_response.subprotocol) + + def _client_on_message_handler(self): + # pylint: disable=unused-argument,invalid-name + def message_handler(ws, message): + if isinstance(message, str): + self.send(text_data=message) + else: + self.send(bytes_data=message) + return message_handler + + def _client_on_error_handler(self): + return lambda ws, error: print(error) + + def _client_on_close_handler(self): + return lambda ws: self.disconnect(0) + + def disconnect(self, code): + self._client.close() + + def receive(self, text_data=None, bytes_data=None): + if text_data: + opcode = websocket.ABNF.OPCODE_TEXT + if bytes_data: + opcode = websocket.ABNF.OPCODE_BINARY + self._client.send(text_data or bytes_data, opcode) diff --git a/passbook/app_gw/websocket/routing.py b/passbook/app_gw/websocket/routing.py new file mode 100644 index 000000000..bbf7b7a83 --- /dev/null +++ b/passbook/app_gw/websocket/routing.py @@ -0,0 +1,17 @@ +"""app_gw websocket proxy""" +from channels.auth import AuthMiddlewareStack +from channels.routing import ProtocolTypeRouter, URLRouter +from django.conf.urls import url + +from passbook.app_gw.websocket.consumer import ProxyConsumer + +websocket_urlpatterns = [ + url(r'^(.*)$', ProxyConsumer), +] + +application = ProtocolTypeRouter({ + # (http->django views is added by default) + 'websocket': AuthMiddlewareStack( + URLRouter(websocket_urlpatterns) + ), +})