diff --git a/authentik/enterprise/models.py b/authentik/enterprise/models.py index fb9d8367d..561fbba7c 100644 --- a/authentik/enterprise/models.py +++ b/authentik/enterprise/models.py @@ -88,7 +88,7 @@ class LicenseKey: @staticmethod def get_total() -> "LicenseKey": """Get a summarized version of all (not expired) licenses""" - active_licenses = License.objects.filter(expiry__gte=now()) + active_licenses = License.non_expired() total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) for lic in active_licenses: total.internal_users += lic.internal_users @@ -167,6 +167,10 @@ class License(SerializerModel): internal_users = models.BigIntegerField() external_users = models.BigIntegerField() + @classmethod + def non_expired(cls) -> QuerySet["License"]: + return License.objects.filter(expiry__gte=now()) + @property def serializer(self) -> type[BaseSerializer]: from authentik.enterprise.api import LicenseSerializer diff --git a/authentik/stages/authenticator_mobile/cloud_gateway.py b/authentik/stages/authenticator_mobile/cloud_gateway.py index 9bbb8197e..b3552a6d0 100644 --- a/authentik/stages/authenticator_mobile/cloud_gateway.py +++ b/authentik/stages/authenticator_mobile/cloud_gateway.py @@ -8,6 +8,8 @@ from grpc import ( UnaryUnaryClientInterceptor, insecure_channel, intercept_channel, + ssl_channel_credentials, + secure_channel, ) from grpc._interceptor import _ClientCallDetails @@ -48,12 +50,28 @@ class AuthInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor) return continuation(self._intercept_client_call_details(client_call_details), request) +@lru_cache() +def get_enterprise_token() -> str: + """Get enterprise license key, if a license is installed, otherwise use the install ID""" + from authentik.root.install_id import get_install_id + + try: + from authentik.enterprise.models import License + + license = License.non_expired().order_by("-expiry").first() + if not license: + return get_install_id() + return license.key + except ImportError: + return get_install_id() + + @lru_cache() def get_client(addr: str): """get a cached client to a cloud-gateway""" - target = addr + channel = secure_channel(addr, ssl_channel_credentials) if settings.DEBUG: - target = insecure_channel(target) - channel = intercept_channel(target, AuthInterceptor("foo")) + channel = insecure_channel(addr) + channel = intercept_channel(addr, AuthInterceptor(get_enterprise_token())) client = AuthenticationPushStub(channel) return client