"""Enterprise models"""
from base64 import b64decode
from binascii import Error
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from uuid import uuid4
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from django.contrib.postgres.indexes import HashIndex
from django.db import models
from django.db.models.query import QuerySet
from django.utils.timezone import now
from django.utils.translation import gettext as _
from guardian.shortcuts import get_anonymous_user
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.serializers import BaseSerializer
from authentik.core.models import ExpiringModel, User, UserTypes
from authentik.lib.models import SerializerModel
from authentik.root.install_id import get_install_id
def get_licensing_key() -> Certificate:
    """Load the licensing root CA certificate from the bundled PEM file.

    Returns:
        The certificate parsed from ``authentik/enterprise/public.pem``.

    Raises:
        OSError: if the PEM file cannot be opened or read.
    """
    # The path is relative, so it is resolved against the process's working directory.
    with open("authentik/enterprise/public.pem", "rb") as _key:
        # load_pem_x509_certificate() takes the PEM-encoded bytes, so hand it
        # the full content of the file that was just opened in binary mode.
        return load_pem_x509_certificate(_key.read())
def get_license_aud() -> str:
    """Return the JWT audience value: the install ID formatted as a string."""
    install_id = get_install_id()
    return f"{install_id}"
class LicenseFlags(Enum):
    """Flags that can be attached to a license.

    The enumeration currently declares no members.
    """
@dataclass
class LicenseKey:
    """Claims carried by a license JWT.

    Declared as a dataclass: the fields below rely on ``field(default_factory=...)``
    and instances are built positionally as ``LicenseKey(aud, exp, name, ...)``.
    """

    # JWT audience the license applies to.
    aud: str
    # Expiry as an integer timestamp (seconds since the epoch).
    exp: int

    # Human-readable name of the license.
    name: str
    # Number of internal users the license covers.
    internal_users: int = 0
    # Number of external users the license covers.
    external_users: int = 0
    # default_factory gives every instance its own list instead of a shared one.
    flags: list[LicenseFlags] = field(default_factory=list)
def validate(jwt: str) -> "LicenseKey":
"""Validate the license from a given JWT"""
# NOTE(review): parts of this function appear to be missing from this copy:
# there is no `try:` ahead of any of the `except` clauses, the `from_dict(`
# call near the end is never closed, and the indentation has been lost.
# The comments below describe only the statements that are visible.
#
# Every failure path that is visible raises ValidationError with the same
# message, "Unable to verify license".
#
# Read the JWT header via get_unverified_header(), i.e. before any
# signature check has happened.
headers = get_unverified_header(jwt)
except PyJWTError:
raise ValidationError("Unable to verify license")
# "x5c" is treated as a list of base64-encoded DER certificates (see the
# b64decode + load_der_x509_certificate calls below); it defaults to [].
x5c: list[str] = headers.get("x5c", [])
# NOTE(review): this guard only rejects an empty list, yet x5c[1] is read
# two statements later. A single-entry list would raise IndexError, which is
# not among the exception types named in the following `except` -- confirm,
# and consider requiring at least two entries here.
if len(x5c) < 1:
raise ValidationError("Unable to verify license")
# First entry is loaded as the license's own certificate, second as the
# intermediate certificate.
our_cert = load_der_x509_certificate(b64decode(x5c[0]))
# NOTE(review): `intermediate` is not used by any visible statement;
# presumably the chain verification that used it is among the missing lines.
intermediate = load_der_x509_certificate(b64decode(x5c[1]))
except (InvalidSignature, TypeError, ValueError, Error):
raise ValidationError("Unable to verify license")
# NOTE(review): the arguments to from_dict() are missing, so how the claims
# are decoded (key, algorithms, audience) cannot be determined from here.
body = from_dict(
except PyJWTError:
raise ValidationError("Unable to verify license")
return body
@staticmethod
def get_total() -> "LicenseKey":
    """Summarize all licenses that have not expired into one ``LicenseKey``.

    The internal and external user limits of every active license are added
    up, and ``exp`` is set to the earliest expiry among them.

    Returns:
        A ``LicenseKey`` named "Summarized license". When there is no active
        license, both user limits and ``exp`` are 0.
    """
    total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
    for lic in License.objects.filter(expiry__gte=now()):
        total.internal_users += lic.internal_users
        total.external_users += lic.external_users
        # datetime.timestamp() honours the value's own UTC offset. Going through
        # timetuple() + time.mktime() would reinterpret the wall-clock fields in
        # the server's local time zone and shift the result on non-UTC hosts.
        exp_ts = int(lic.expiry.timestamp())
        # exp == 0 means nothing has been recorded yet; otherwise keep the
        # earliest expiry seen so far.
        total.exp = exp_ts if total.exp == 0 else min(total.exp, exp_ts)
    return total
@staticmethod
def base_user_qs() -> QuerySet:
    """Return the queryset that user counts are built from.

    Declared static because the function takes no ``self``; without the
    decorator it could only be reached through the class, and any access
    through an instance would fail with a ``TypeError``.

    Returns:
        All users except the anonymous user returned by ``get_anonymous_user()``.
    """
    return User.objects.all().exclude(pk=get_anonymous_user().pk)
@staticmethod
def get_default_user_count() -> int:
    """Count the users whose type is ``UserTypes.INTERNAL``.

    Declared static because ``is_valid`` calls it through an instance
    (``self.get_default_user_count()``) while the function takes no ``self``.

    Returns:
        Number of internal users, not counting the anonymous user.
    """
    return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@staticmethod
def get_external_user_count() -> int:
    """Count external users that have logged in since the 1st of this month.

    Declared static because ``is_valid`` calls it through an instance
    (``self.get_external_user_count()``) while the function takes no ``self``.

    Returns:
        Number of users of type ``UserTypes.EXTERNAL`` whose ``last_login`` is
        on or after the cut-off, not counting the anonymous user.
    """
    # NOTE(review): replace(day=1) keeps the current time of day, so the window
    # opens on the 1st at the current clock time rather than at midnight --
    # confirm this is intended.
    month_start = now().replace(day=1)
    # NOTE(review): the base queryset and the trailing .count() mirror
    # get_default_user_count(); verify against the intended query.
    return (
        LicenseKey.base_user_qs()
        .filter(type=UserTypes.EXTERNAL, last_login__gte=month_start)
        .count()
    )
def is_valid(self) -> bool:
    """Report whether this license covers the current user counts.

    The internal user count is checked first; the external user count is
    only queried when the internal limit is not exceeded. Only the counts at
    the time of the call are considered, no historical data is consulted.

    Returns:
        ``False`` as soon as either count exceeds its limit, ``True`` otherwise.
    """
    if self.get_default_user_count() > self.internal_users:
        return False
    external_in_use = self.get_external_user_count()
    return not external_in_use > self.external_users
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
# NOTE(review): this function appears to be cut off in this copy: the `if`
# below has no body and no statement that saves anything is visible. The
# comments here describe only the lines that are present.
#
# Cut-off timestamp eight hours before the current time.
threshold = now() - timedelta(hours=8)
# Checks whether any LicenseUsage row has a record_date at or after the
# cut-off; presumably recording is skipped in that case -- confirm against
# the complete source.
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
# NOTE(review): the query assigned to `usage` is missing from this copy (the
# parenthesis opened on the assignment is never closed), so which record is
# selected cannot be determined from here.
#
# What is visible: when no usage record is found the current time is
# returned; otherwise the record's record_date is returned.
usage: LicenseUsage = (
if not usage:
return now()
return usage.record_date
class License(SerializerModel):
    """Database record of an authentik enterprise license."""

    license_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
    # Raw license text; status() passes it to LicenseKey.validate().
    key = models.TextField()
    name = models.TextField()
    expiry = models.DateTimeField()
    internal_users = models.BigIntegerField()
    external_users = models.BigIntegerField()

    def serializer(self) -> type[BaseSerializer]:
        """Return the serializer class used for licenses."""
        # Imported inside the method rather than at module level.
        from authentik.enterprise.api import LicenseSerializer

        return LicenseSerializer

    def status(self) -> LicenseKey:
        """Validate the stored key and return the parsed license claims."""
        stored_key = self.key
        return LicenseKey.validate(stored_key)

    class Meta:
        # Hash index on the key column.
        indexes = (HashIndex(fields=("key",)),)
        verbose_name = _("License")
        verbose_name_plural = _("Licenses")
def usage_expiry():
    """Return the default expiry for usage records: three 30-day months from now."""
    retention = timedelta(days=30 * 3)
    return now() + retention
class LicenseUsage(ExpiringModel):
    """One recorded snapshot of license usage."""

    # Expiry timestamp; defaults to the value returned by usage_expiry().
    expires = models.DateTimeField(default=usage_expiry)

    usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)

    # Counts captured for this record and whether they were within the limits.
    user_count = models.BigIntegerField()
    external_user_count = models.BigIntegerField()
    within_limits = models.BooleanField()

    # Filled in automatically when the row is created (auto_now_add).
    record_date = models.DateTimeField(auto_now_add=True)

    class Meta:
        verbose_name = _("License Usage")
        verbose_name_plural = _("License Usage Records")