root: partial Live-updating config (#5959)

* stages/email: directly use email credentials from config

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use custom database backend that supports dynamic credentials

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add crude config reloader

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make method names for CONFIG clearer

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* replace config.set with environ

Not sure if this is the cleanest way, but it persists through a config reload

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* re-add set for @patch

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* even more crudeness

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* clean up some old stuff?

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* somewhat rewrite config loader to keep track of a source of an attribute so we can refresh it

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* cleanup old things

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix flow e2e

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-07-19 23:13:22 +02:00 committed by GitHub
parent fb4e4dc8db
commit 2f469d2709
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
44 changed files with 260 additions and 184 deletions

View file

@ -58,7 +58,7 @@ def clear_update_notifications():
@prefill_task
def update_latest_version(self: MonitoredTask):
"""Update latest version info"""
if CONFIG.y_bool("disable_update_check"):
if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
return

View file

@ -70,7 +70,7 @@ class ConfigView(APIView):
caps.append(Capabilities.CAN_SAVE_MEDIA)
if GEOIP_READER.enabled:
caps.append(Capabilities.CAN_GEO_IP)
if CONFIG.y_bool("impersonation"):
if CONFIG.get_bool("impersonation"):
caps.append(Capabilities.CAN_IMPERSONATE)
if settings.DEBUG: # pragma: no cover
caps.append(Capabilities.CAN_DEBUG)
@ -86,17 +86,17 @@ class ConfigView(APIView):
return ConfigSerializer(
{
"error_reporting": {
"enabled": CONFIG.y("error_reporting.enabled"),
"sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"),
"environment": CONFIG.y("error_reporting.environment"),
"send_pii": CONFIG.y("error_reporting.send_pii"),
"traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)),
"enabled": CONFIG.get("error_reporting.enabled"),
"sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"),
"environment": CONFIG.get("error_reporting.environment"),
"send_pii": CONFIG.get("error_reporting.send_pii"),
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
},
"capabilities": self.get_capabilities(),
"cache_timeout": int(CONFIG.y("redis.cache_timeout")),
"cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")),
"cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")),
"cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")),
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
}
)

View file

@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
return
blueprint_file.seek(0)
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir")))
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
meta = None
if metadata:
meta = from_dict(BlueprintMetadata, metadata)
@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
Flow = apps.get_model("authentik_flows", "Flow")
db_alias = schema_editor.connection.alias
for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True):
for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
check_blueprint_v1_file(BlueprintInstance, Path(file))
for blueprint in BlueprintInstance.objects.using(db_alias).all():

View file

@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
def retrieve_file(self) -> str:
"""Get blueprint from path"""
try:
base = Path(CONFIG.y("blueprints_dir"))
base = Path(CONFIG.get("blueprints_dir"))
full_path = base.joinpath(Path(self.path)).resolve()
if not str(full_path).startswith(str(base.resolve())):
raise BlueprintRetrievalFailed("Invalid blueprint path")

View file

@ -62,7 +62,7 @@ def start_blueprint_watcher():
if _file_watcher_started:
return
observer = Observer()
observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True)
observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
observer.start()
_file_watcher_started = True
@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
blueprints_discovery.delay()
if isinstance(event, FileModifiedEvent):
path = Path(event.src_path)
root = Path(CONFIG.y("blueprints_dir")).absolute()
root = Path(CONFIG.get("blueprints_dir")).absolute()
rel_path = str(path.relative_to(root))
for instance in BlueprintInstance.objects.filter(path=rel_path):
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
@ -101,7 +101,7 @@ def blueprints_find_dict():
def blueprints_find():
"""Find blueprints and return valid ones"""
blueprints = []
root = Path(CONFIG.y("blueprints_dir"))
root = Path(CONFIG.get("blueprints_dir"))
for path in root.rglob("**/*.yaml"):
# Check if any part in the path starts with a dot and assume a hidden file
if any(part for part in path.parts if part.startswith(".")):

View file

@ -596,7 +596,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
@action(detail=True, methods=["POST"])
def impersonate(self, request: Request, pk: int) -> Response:
"""Impersonate a user"""
if not CONFIG.y_bool("impersonation"):
if not CONFIG.get_bool("impersonation"):
LOGGER.debug("User attempted to impersonate", user=request.user)
return Response(status=401)
if not request.user.has_perm("impersonate"):

View file

@ -18,7 +18,7 @@ class Command(BaseCommand):
def handle(self, **options):
close_old_connections()
if CONFIG.y_bool("remote_debug"):
if CONFIG.get_bool("remote_debug"):
import debugpy
debugpy.listen(("0.0.0.0", 6900)) # nosec

View file

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key"""
# We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery)
return generate_id(int(CONFIG.y("default_token_length")))
return generate_id(int(CONFIG.get("default_token_length")))
class UserTypes(models.TextChoices):

View file

@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask):
certs = {}
private_keys = {}
discovered = 0
for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True):
for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True):
path = Path(file)
if not path.exists():
continue

View file

@ -33,7 +33,7 @@ class GeoIPReader:
def __open(self):
"""Get GeoIP Reader, if configured, otherwise none"""
path = CONFIG.y("geoip")
path = CONFIG.get("geoip")
if path == "" or not path:
return
try:
@ -46,7 +46,7 @@ class GeoIPReader:
def __check_expired(self):
"""Check if the modification date of the GeoIP database has
changed, and reload it if so"""
path = CONFIG.y("geoip")
path = CONFIG.get("geoip")
try:
mtime = stat(path).st_mtime
diff = self.__last_mtime < mtime

View file

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows"))
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
CACHE_PREFIX = "goauthentik.io/flows/planner/"

View file

@ -18,7 +18,6 @@ from authentik.flows.planner import FlowPlan, FlowPlanner
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
@ -85,7 +84,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
@ -111,7 +109,6 @@ class TestFlowExecutor(FlowTestCase):
denied_action=FlowDeniedAction.CONTINUE,
)
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
@ -128,7 +125,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
dest = "/unique-string"
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
@ -145,7 +141,6 @@ class TestFlowExecutor(FlowTestCase):
FlowDesignation.AUTHENTICATION,
)
CONFIG.update_from_dict({"domain": "testserver"})
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)

View file

@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str:
"initials": avatar_mode_generated,
"gravatar": avatar_mode_gravatar,
}
modes: str = CONFIG.y("avatars", "none")
modes: str = CONFIG.get("avatars", "none")
for mode in modes.split(","):
avatar = None
if mode in mode_map:

View file

@ -2,13 +2,15 @@
import os
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from glob import glob
from json import dumps, loads
from json import JSONEncoder, dumps, loads
from json.decoder import JSONDecodeError
from pathlib import Path
from sys import argv, stderr
from time import time
from typing import Any
from typing import Any, Optional
from urllib.parse import urlparse
import yaml
@ -32,15 +34,44 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
return root
@dataclass
class Attr:
"""Single configuration attribute"""
class Source(Enum):
"""Sources a configuration attribute can come from, determines what should be done with
Attr.source (and if it's set at all)"""
UNSPECIFIED = "unspecified"
ENV = "env"
CONFIG_FILE = "config_file"
URI = "uri"
value: Any
source_type: Source = field(default=Source.UNSPECIFIED)
# depending on source_type, might contain the environment variable or the path
# to the config file containing this change or the file containing this value
source: Optional[str] = field(default=None)
class AttrEncoder(JSONEncoder):
"""JSON encoder that can deal with `Attr` classes"""
def default(self, o: Any) -> Any:
if isinstance(o, Attr):
return o.value
return super().default(o)
class ConfigLoader:
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with
`ENV_PREFIX` are also applied.
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
loaded_file = []
def __init__(self):
def __init__(self, **kwargs):
super().__init__()
self.__config = {}
base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
@ -65,6 +96,7 @@ class ConfigLoader:
# Update config with env file
self.update_from_file(env_file)
self.update_from_env()
self.update(self.__config, kwargs)
def log(self, level: str, message: str, **kwargs):
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
@ -86,22 +118,32 @@ class ConfigLoader:
else:
if isinstance(value, str):
value = self.parse_uri(value)
elif not isinstance(value, Attr):
value = Attr(value)
root[key] = value
return root
def parse_uri(self, value: str) -> str:
def refresh(self, key: str):
"""Update a single value"""
attr: Attr = get_path_from_dict(self.raw, key)
if attr.source_type != Attr.Source.URI:
return
attr.value = self.parse_uri(attr.source).value
def parse_uri(self, value: str) -> Attr:
"""Parse string values which start with a URI"""
url = urlparse(value)
parsed_value = value
if url.scheme == "env":
value = os.getenv(url.netloc, url.query)
parsed_value = os.getenv(url.netloc, url.query)
if url.scheme == "file":
try:
with open(url.path, "r", encoding="utf8") as _file:
value = _file.read().strip()
parsed_value = _file.read().strip()
except OSError as exc:
self.log("error", f"Failed to read config value from {url.path}: {exc}")
value = url.query
return value
parsed_value = url.query
return Attr(parsed_value, Attr.Source.URI, value)
def update_from_file(self, path: Path):
"""Update config from file contents"""
@ -110,7 +152,6 @@ class ConfigLoader:
try:
self.update(self.__config, yaml.safe_load(file))
self.log("debug", "Loaded config", file=str(path))
self.loaded_file.append(path)
except yaml.YAMLError as exc:
raise ImproperlyConfigured from exc
except PermissionError as exc:
@ -121,10 +162,6 @@ class ConfigLoader:
error=str(exc),
)
def update_from_dict(self, update: dict):
"""Update config from dict"""
self.__config.update(update)
def update_from_env(self):
"""Check environment variables"""
outer = {}
@ -145,7 +182,7 @@ class ConfigLoader:
value = loads(value)
except JSONDecodeError:
pass
current_obj[dot_parts[-1]] = value
current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
idx += 1
if idx > 0:
self.log("debug", "Loaded environment variables", count=idx)
@ -154,28 +191,32 @@ class ConfigLoader:
@contextmanager
def patch(self, path: str, value: Any):
"""Context manager for unittests to patch a value"""
original_value = self.y(path)
self.y_set(path, value)
original_value = self.get(path)
self.set(path, value)
try:
yield
finally:
self.y_set(path, original_value)
self.set(path, original_value)
@property
def raw(self) -> dict:
"""Get raw config dictionary"""
return self.__config
# pylint: disable=invalid-name
def y(self, path: str, default=None, sep=".") -> Any:
def get(self, path: str, default=None, sep=".") -> Any:
"""Access attribute by using yaml path"""
# Walk sub_dicts before parsing path
root = self.raw
# Walk each component of the path
return get_path_from_dict(root, path, sep=sep, default=default)
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value
def y_set(self, path: str, value: Any, sep="."):
"""Set value using same syntax as y()"""
def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true"
def set(self, path: str, value: Any, sep="."):
"""Set value using same syntax as get()"""
# Walk sub_dicts before parsing path
root = self.raw
# Walk each component of the path
@ -184,17 +225,14 @@ class ConfigLoader:
if comp not in root:
root[comp] = {}
root = root.get(comp, {})
root[path_parts[-1]] = value
def y_bool(self, path: str, default=False) -> bool:
"""Wrapper for y that converts value into boolean"""
return str(self.y(path, default)).lower() == "true"
root[path_parts[-1]] = Attr(value)
CONFIG = ConfigLoader()
if __name__ == "__main__":
if len(argv) < 2:
print(dumps(CONFIG.raw, indent=4))
print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
else:
print(CONFIG.y(argv[1]))
print(CONFIG.get(argv[1]))

View file

@ -51,18 +51,18 @@ class SentryTransport(HttpTransport):
def sentry_init(**sentry_init_kwargs):
"""Configure sentry SDK"""
sentry_env = CONFIG.y("error_reporting.environment", "customer")
sentry_env = CONFIG.get("error_reporting.environment", "customer")
kwargs = {
"environment": sentry_env,
"send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False),
"send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False),
"_experiments": {
"profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)),
"profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)),
},
}
kwargs.update(**sentry_init_kwargs)
# pylint: disable=abstract-class-instantiated
sentry_sdk_init(
dsn=CONFIG.y("error_reporting.sentry_dsn"),
dsn=CONFIG.get("error_reporting.sentry_dsn"),
integrations=[
ArgvIntegration(),
StdlibIntegration(),
@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float:
return 0
if _type == "websocket":
return 0
return float(CONFIG.y("error_reporting.sample_rate", 0.1))
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def before_send(event: dict, hint: dict) -> Optional[dict]:

View file

@ -16,23 +16,23 @@ class TestConfig(TestCase):
config = ConfigLoader()
environ[ENV_PREFIX + "_test__test"] = "bar"
config.update_from_env()
self.assertEqual(config.y("test.test"), "bar")
self.assertEqual(config.get("test.test"), "bar")
def test_patch(self):
"""Test patch decorator"""
config = ConfigLoader()
config.y_set("foo.bar", "bar")
self.assertEqual(config.y("foo.bar"), "bar")
config.set("foo.bar", "bar")
self.assertEqual(config.get("foo.bar"), "bar")
with config.patch("foo.bar", "baz"):
self.assertEqual(config.y("foo.bar"), "baz")
self.assertEqual(config.y("foo.bar"), "bar")
self.assertEqual(config.get("foo.bar"), "baz")
self.assertEqual(config.get("foo.bar"), "bar")
def test_uri_env(self):
"""Test URI parsing (environment)"""
config = ConfigLoader()
environ["foo"] = "bar"
self.assertEqual(config.parse_uri("env://foo"), "bar")
self.assertEqual(config.parse_uri("env://foo?bar"), "bar")
self.assertEqual(config.parse_uri("env://foo").value, "bar")
self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
def test_uri_file(self):
"""Test URI parsing (file load)"""
@ -41,11 +41,25 @@ class TestConfig(TestCase):
write(file, "foo".encode())
_, file2_name = mkstemp()
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo")
self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def")
unlink(file_name)
unlink(file2_name)
def test_uri_file_update(self):
"""Test URI parsing (file load and update)"""
file, file_name = mkstemp()
write(file, "foo".encode())
config = ConfigLoader(file_test=f"file://{file_name}")
self.assertEqual(config.get("file_test"), "foo")
# Update config file
write(file, "bar".encode())
config.refresh("file_test")
self.assertEqual(config.get("file_test"), "foobar")
unlink(file_name)
def test_file_update(self):
"""Test update_from_file"""
config = ConfigLoader()

View file

@ -50,7 +50,7 @@ def get_env() -> str:
"""Get environment in which authentik is currently running"""
if "CI" in os.environ:
return "ci"
if CONFIG.y_bool("debug"):
if CONFIG.get_bool("debug"):
return "dev"
if SERVICE_HOST_ENV_NAME in os.environ:
return "kubernetes"

View file

@ -97,7 +97,7 @@ class BaseController:
if self.outpost.config.container_image is not None:
return self.outpost.config.container_image
image_name_template: str = CONFIG.y("outposts.container_image_base")
image_name_template: str = CONFIG.get("outposts.container_image_base")
return image_name_template % {
"type": self.outpost.type,
"version": __version__,

View file

@ -58,7 +58,7 @@ class OutpostConfig:
authentik_host_insecure: bool = False
authentik_host_browser: str = ""
log_level: str = CONFIG.y("log_level")
log_level: str = CONFIG.get("log_level")
object_naming_template: str = field(default="ak-outpost-%(name)s")
container_image: Optional[str] = field(default=None)

View file

@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
def outpost_connection_discovery(self: MonitoredTask):
"""Checks the local environment and create Service connections."""
status = TaskResult(TaskResultStatus.SUCCESSFUL)
if not CONFIG.y_bool("outposts.discover"):
if not CONFIG.get_bool("outposts.discover"):
status.messages.append("Outpost integration discovery is disabled")
self.set_status(status)
return

View file

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger()
FORK_CTX = get_context("fork")
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies"))
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
PROCESS_CLASS = FORK_CTX.Process

View file

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger()
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation"))
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
def update_score(request: HttpRequest, identifier: str, amount: int):

View file

@ -46,7 +46,7 @@ class DeviceView(View):
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
throttle = AnonRateThrottle()
throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour")
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
if not throttle.allow_request(request, self):
return HttpResponse(status=429)

View file

15
authentik/root/db/base.py Normal file
View file

@ -0,0 +1,15 @@
"""authentik database backend"""
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG
class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials"""
def get_connection_params(self):
CONFIG.refresh("postgresql.password")
conn_params = super().get_connection_params()
conn_params["user"] = CONFIG.get("postgresql.user")
conn_params["password"] = CONFIG.get("postgresql.password")
return conn_params

View file

@ -26,15 +26,15 @@ def get_install_id_raw():
"""Get install_id without django loaded, this is required for the startup when we get
the install_id but django isn't loaded yet and we can't use the function above."""
conn = connect(
dbname=CONFIG.y("postgresql.name"),
user=CONFIG.y("postgresql.user"),
password=CONFIG.y("postgresql.password"),
host=CONFIG.y("postgresql.host"),
port=int(CONFIG.y("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"),
dbname=CONFIG.get("postgresql.name"),
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.get("postgresql.sslkey"),
)
cursor = conn.cursor()
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")

View file

@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent
STATICFILES_DIRS = [BASE_DIR / Path("web")]
MEDIA_ROOT = BASE_DIR / Path("media")
DEBUG = CONFIG.y_bool("debug")
SECRET_KEY = CONFIG.y("secret_key")
DEBUG = CONFIG.get_bool("debug")
SECRET_KEY = CONFIG.get("secret_key")
INTERNAL_IPS = ["127.0.0.1"]
ALLOWED_HOSTS = ["*"]
@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf"
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
LANGUAGE_COOKIE_NAME = "authentik_language"
SESSION_COOKIE_NAME = "authentik_session"
SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None)
SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
AUTHENTICATION_BACKENDS = [
"django.contrib.auth.backends.ModelBackend",
@ -179,26 +179,26 @@ REST_FRAMEWORK = {
"TEST_REQUEST_DEFAULT_FORMAT": "json",
"DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
"DEFAULT_THROTTLE_RATES": {
"anon": CONFIG.y("throttle.default"),
"anon": CONFIG.get("throttle.default"),
},
}
_redis_protocol_prefix = "redis://"
_redis_celery_tls_requirements = ""
if CONFIG.y_bool("redis.tls", False):
if CONFIG.get_bool("redis.tls", False):
_redis_protocol_prefix = "rediss://"
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}"
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
_redis_url = (
f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
f"{int(CONFIG.y('redis.port'))}"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}"
)
CACHES = {
"default": {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}",
"TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)),
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache",
}
@ -238,7 +238,7 @@ ROOT_URLCONF = "authentik.root.urls"
TEMPLATES = [
{
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [CONFIG.y("email.template_dir")],
"DIRS": [CONFIG.get("email.template_dir")],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
@ -258,7 +258,7 @@ CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"],
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
"prefix": "authentik_channels",
},
},
@ -270,34 +270,37 @@ CHANNEL_LAYERS = {
DATABASES = {
"default": {
"ENGINE": "django_prometheus.db.backends.postgresql",
"HOST": CONFIG.y("postgresql.host"),
"NAME": CONFIG.y("postgresql.name"),
"USER": CONFIG.y("postgresql.user"),
"PASSWORD": CONFIG.y("postgresql.password"),
"PORT": int(CONFIG.y("postgresql.port")),
"SSLMODE": CONFIG.y("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"),
"SSLCERT": CONFIG.y("postgresql.sslcert"),
"SSLKEY": CONFIG.y("postgresql.sslkey"),
"ENGINE": "authentik.root.db",
"HOST": CONFIG.get("postgresql.host"),
"NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.get("postgresql.password"),
"PORT": int(CONFIG.get("postgresql.port")),
"SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.get("postgresql.sslcert"),
"SSLKEY": CONFIG.get("postgresql.sslkey"),
}
}
if CONFIG.y_bool("postgresql.use_pgbouncer", False):
if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
# Email
EMAIL_HOST = CONFIG.y("email.host")
EMAIL_PORT = int(CONFIG.y("email.port"))
EMAIL_HOST_USER = CONFIG.y("email.username")
EMAIL_HOST_PASSWORD = CONFIG.y("email.password")
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.y("email.from")
# These values should never actually be used, emails are only sent from email stages, which
# loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] "
@ -345,15 +348,15 @@ CELERY = {
},
"task_create_missing_queues": True,
"task_default_queue": "authentik",
"broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
"result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
}
# Sentry integration
env = get_env()
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False)
_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
if _ERROR_REPORTING:
sentry_env = CONFIG.y("error_reporting.environment", "customer")
sentry_env = CONFIG.get("error_reporting.environment", "customer")
sentry_init()
set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])
@ -367,7 +370,7 @@ MEDIA_URL = "/media/"
TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# We can't check TEST here as its set later by the test runner
LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
# Additionally, the entire code uses debug as highest level so that would have to be re-written too

View file

@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover
settings.TEST = True
settings.CELERY["task_always_eager"] = True
CONFIG.y_set("avatars", "none")
CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.y_set("blueprints_dir", "./blueprints")
CONFIG.y_set(
CONFIG.set("avatars", "none")
CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.set("blueprints_dir", "./blueprints")
CONFIG.set(
"outposts.container_image_base",
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
)
CONFIG.y_set("error_reporting.sample_rate", 0)
CONFIG.set("error_reporting.sample_rate", 0)
sentry_init(
environment="testing",
send_default_pii=True,

View file

@ -136,7 +136,7 @@ class LDAPSource(Source):
chmod(private_key_file, 0o600)
tls_kwargs["local_private_key_file"] = private_key_file
tls_kwargs["local_certificate_file"] = certificate_file
if ciphers := CONFIG.y("ldap.tls.ciphers", None):
if ciphers := CONFIG.get("ldap.tls.ciphers", None):
tls_kwargs["ciphers"] = ciphers.strip()
if self.sni:
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()

View file

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False,
get_operational_attributes=False,
controls=None,
paged_size=int(CONFIG.y("ldap.page_size", 50)),
paged_size=int(CONFIG.get("ldap.page_size", 50)),
paged_criticality=False,
):
"""Search in pages, returns each page"""

View file

@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
)
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source:
# Because the source couldn't be found, we don't have a UID

View file

@ -13,6 +13,7 @@ from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger
from authentik.flows.models import Stage
from authentik.lib.config import CONFIG
LOGGER = get_logger()
@ -104,7 +105,16 @@ class EmailStage(Stage):
def backend(self) -> BaseEmailBackend:
"""Get fully configured Email Backend instance"""
if self.use_global_settings:
return self.backend_class()
CONFIG.refresh("email.password")
return self.backend_class(
host=CONFIG.get("email.host"),
port=int(CONFIG.get("email.port")),
username=CONFIG.get("email.username"),
password=CONFIG.get("email.password"),
use_tls=CONFIG.get_bool("email.use_tls", False),
use_ssl=CONFIG.get_bool("email.use_ssl", False),
timeout=int(CONFIG.get("email.timeout")),
)
return self.backend_class(
host=self.host,
port=self.port,

View file

@ -13,6 +13,7 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.tests import FlowTestCase
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.email.models import EmailStage
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
@ -120,7 +121,7 @@ class TestEmailStage(FlowTestCase):
def test_use_global_settings(self):
"""Test use_global_settings"""
host = "some-unique-string"
with self.settings(EMAIL_HOST=host):
with CONFIG.patch("email.host", host):
self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)
def test_token(self):

View file

@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer):
ui_footer_links = ListField(
child=FooterLinkSerializer(),
read_only=True,
default=CONFIG.y("footer_links", []),
default=CONFIG.get("footer_links", []),
)
ui_theme = ChoiceField(
choices=Themes.choices,

View file

@ -24,7 +24,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik",
"matched_domain": tenant.domain,
"ui_footer_links": CONFIG.y("footer_links"),
"ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
@ -43,7 +43,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "custom",
"matched_domain": "bar.baz",
"ui_footer_links": CONFIG.y("footer_links"),
"ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},
@ -59,7 +59,7 @@ class TestTenants(APITestCase):
"branding_favicon": "/static/dist/assets/icons/icon.png",
"branding_title": "authentik",
"matched_domain": "fallback",
"ui_footer_links": CONFIG.y("footer_links"),
"ui_footer_links": CONFIG.get("footer_links"),
"ui_theme": Themes.AUTOMATIC,
"default_locale": "",
},

View file

@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
trace = span.to_traceparent()
return {
"tenant": tenant,
"footer_links": CONFIG.y("footer_links"),
"footer_links": CONFIG.get("footer_links"),
"sentry_trace": trace,
"version": get_full_version(),
}

View file

@ -94,21 +94,21 @@ entries:
prompt_data = request.context.get("prompt_data")
if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True)
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True)
):
if prompt_data.get("email") != request.user.email:
ak_message("Not allowed to change email address.")
return False
if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True)
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True)
):
if prompt_data.get("name") != request.user.name:
ak_message("Not allowed to change name.")
return False
if not request.user.group_attributes(request.http_request).get(
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True)
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True)
):
if prompt_data.get("username") != request.user.username:
ak_message("Not allowed to change username.")

View file

@ -37,7 +37,7 @@ makedirs(prometheus_tmp_dir, exist_ok=True)
max_requests = 1000
max_requests_jitter = 50
_debug = CONFIG.y_bool("DEBUG", False)
_debug = CONFIG.get_bool("DEBUG", False)
logconfig_dict = {
"version": 1,
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
else:
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
workers = int(CONFIG.y("web.workers", default_workers))
threads = int(CONFIG.y("web.threads", 4))
workers = int(CONFIG.get("web.workers", default_workers))
threads = int(CONFIG.get("web.threads", 4))
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
@ -133,7 +133,7 @@ def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
worker._worker_id = _next_worker_id(server)
if not CONFIG.y_bool("disable_startup_analytics", False):
if not CONFIG.get_bool("disable_startup_analytics", False):
env = get_env()
should_send = env not in ["dev", "ci"]
if should_send:
@ -158,7 +158,7 @@ if not CONFIG.y_bool("disable_startup_analytics", False):
except Exception: # nosec
pass
if CONFIG.y_bool("remote_debug"):
if CONFIG.get_bool("remote_debug"):
import debugpy
debugpy.listen(("0.0.0.0", 6800)) # nosec

View file

@ -52,15 +52,15 @@ def release_lock():
if __name__ == "__main__":
conn = connect(
dbname=CONFIG.y("postgresql.name"),
user=CONFIG.y("postgresql.user"),
password=CONFIG.y("postgresql.password"),
host=CONFIG.y("postgresql.host"),
port=int(CONFIG.y("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"),
dbname=CONFIG.get("postgresql.name"),
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.get("postgresql.sslkey"),
)
curr = conn.cursor()
try:

View file

@ -25,7 +25,7 @@ class Migration(BaseMigration):
# If we already have migrations in the database, assume we're upgrading an existing install
# and set the install id to the secret key
self.cur.execute(
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.y("secret_key"),)
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.get("secret_key"),)
)
else:
# Otherwise assume a new install, generate an install ID based on a UUID

View file

@ -108,14 +108,14 @@ class Migration(BaseMigration):
self.con.commit()
# We also need to clean the cache to make sure no pickeled objects still exist
for db in [
CONFIG.y("redis.message_queue_db"),
CONFIG.y("redis.cache_db"),
CONFIG.y("redis.ws_db"),
CONFIG.get("redis.message_queue_db"),
CONFIG.get("redis.cache_db"),
CONFIG.get("redis.ws_db"),
]:
redis = Redis(
host=CONFIG.y("redis.host"),
host=CONFIG.get("redis.host"),
port=6379,
db=db,
password=CONFIG.y("redis.password"),
password=CONFIG.get("redis.password"),
)
redis.flushall()

View file

@ -14,7 +14,7 @@ from authentik.lib.config import CONFIG
CONFIG.log("info", "Starting authentik bootstrap")
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
if CONFIG.get("secret_key") is None or len(CONFIG.get("secret_key")) == 0:
CONFIG.log("info", "----------------------------------------------------------------------")
CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.")
CONFIG.log("info", "----------------------------------------------------------------------")
@ -24,15 +24,15 @@ if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
while True:
try:
conn = connect(
dbname=CONFIG.y("postgresql.name"),
user=CONFIG.y("postgresql.user"),
password=CONFIG.y("postgresql.password"),
host=CONFIG.y("postgresql.host"),
port=int(CONFIG.y("postgresql.port")),
sslmode=CONFIG.y("postgresql.sslmode"),
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
sslcert=CONFIG.y("postgresql.sslcert"),
sslkey=CONFIG.y("postgresql.sslkey"),
dbname=CONFIG.get("postgresql.name"),
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),
sslkey=CONFIG.get("postgresql.sslkey"),
)
conn.cursor()
break
@ -42,12 +42,12 @@ while True:
CONFIG.log("info", "PostgreSQL connection successful")
REDIS_PROTOCOL_PREFIX = "redis://"
if CONFIG.y_bool("redis.tls", False):
if CONFIG.get_bool("redis.tls", False):
REDIS_PROTOCOL_PREFIX = "rediss://"
REDIS_URL = (
f"{REDIS_PROTOCOL_PREFIX}:"
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
f"{int(CONFIG.y('redis.port'))}/{CONFIG.y('redis.db')}"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
)
while True:
try:

View file

@ -1,7 +1,6 @@
"""Test Enroll flow"""
from time import sleep
from django.test import override_settings
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait
@ -9,6 +8,7 @@ from selenium.webdriver.support.wait import WebDriverWait
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User
from authentik.flows.models import Flow
from authentik.lib.config import CONFIG
from authentik.stages.identification.models import IdentificationStage
from tests.e2e.utils import SeleniumTestCase, retry
@ -56,7 +56,7 @@ class TestFlowsEnroll(SeleniumTestCase):
@apply_blueprint(
"example/flows-enrollment-email-verification.yaml",
)
@override_settings(EMAIL_PORT=1025)
@CONFIG.patch("email.port", 1025)
def test_enroll_email(self):
"""Test enroll with Email verification"""
# Attach enrollment flow to identification stage

View file

@ -1,7 +1,6 @@
"""Test recovery flow"""
from time import sleep
from django.test import override_settings
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait
@ -10,6 +9,7 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import User
from authentik.core.tests.utils import create_test_admin_user
from authentik.flows.models import Flow
from authentik.lib.config import CONFIG
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage
from tests.e2e.utils import SeleniumTestCase, retry
@ -47,7 +47,7 @@ class TestFlowsRecovery(SeleniumTestCase):
@apply_blueprint(
"example/flows-recovery-email-verification.yaml",
)
@override_settings(EMAIL_PORT=1025)
@CONFIG.patch("email.port", 1025)
def test_recover_email(self):
"""Test recovery with Email verification"""
# Attach recovery flow to identification stage