get enterprise token
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
938f6fe439
commit
cf93445b3f
|
@ -88,7 +88,7 @@ class LicenseKey:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_total() -> "LicenseKey":
|
def get_total() -> "LicenseKey":
|
||||||
"""Get a summarized version of all (not expired) licenses"""
|
"""Get a summarized version of all (not expired) licenses"""
|
||||||
active_licenses = License.objects.filter(expiry__gte=now())
|
active_licenses = License.non_expired()
|
||||||
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
||||||
for lic in active_licenses:
|
for lic in active_licenses:
|
||||||
total.internal_users += lic.internal_users
|
total.internal_users += lic.internal_users
|
||||||
|
@ -167,6 +167,10 @@ class License(SerializerModel):
|
||||||
internal_users = models.BigIntegerField()
|
internal_users = models.BigIntegerField()
|
||||||
external_users = models.BigIntegerField()
|
external_users = models.BigIntegerField()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def non_expired(cls) -> QuerySet["License"]:
|
||||||
|
return License.objects.filter(expiry__gte=now())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> type[BaseSerializer]:
|
def serializer(self) -> type[BaseSerializer]:
|
||||||
from authentik.enterprise.api import LicenseSerializer
|
from authentik.enterprise.api import LicenseSerializer
|
||||||
|
|
|
@ -8,6 +8,8 @@ from grpc import (
|
||||||
UnaryUnaryClientInterceptor,
|
UnaryUnaryClientInterceptor,
|
||||||
insecure_channel,
|
insecure_channel,
|
||||||
intercept_channel,
|
intercept_channel,
|
||||||
|
ssl_channel_credentials,
|
||||||
|
secure_channel,
|
||||||
)
|
)
|
||||||
from grpc._interceptor import _ClientCallDetails
|
from grpc._interceptor import _ClientCallDetails
|
||||||
|
|
||||||
|
@ -48,12 +50,28 @@ class AuthInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor)
|
||||||
return continuation(self._intercept_client_call_details(client_call_details), request)
|
return continuation(self._intercept_client_call_details(client_call_details), request)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_enterprise_token() -> str:
|
||||||
|
"""Get enterprise license key, if a license is installed, otherwise use the install ID"""
|
||||||
|
from authentik.root.install_id import get_install_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
from authentik.enterprise.models import License
|
||||||
|
|
||||||
|
license = License.non_expired().order_by("-expiry").first()
|
||||||
|
if not license:
|
||||||
|
return get_install_id()
|
||||||
|
return license.key
|
||||||
|
except ImportError:
|
||||||
|
return get_install_id()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_client(addr: str):
|
def get_client(addr: str):
|
||||||
"""get a cached client to a cloud-gateway"""
|
"""get a cached client to a cloud-gateway"""
|
||||||
target = addr
|
channel = secure_channel(addr, ssl_channel_credentials)
|
||||||
if settings.DEBUG:
|
if settings.DEBUG:
|
||||||
target = insecure_channel(target)
|
channel = insecure_channel(addr)
|
||||||
channel = intercept_channel(target, AuthInterceptor("foo"))
|
channel = intercept_channel(addr, AuthInterceptor(get_enterprise_token()))
|
||||||
client = AuthenticationPushStub(channel)
|
client = AuthenticationPushStub(channel)
|
||||||
return client
|
return client
|
||||||
|
|
Reference in New Issue