root: partial Live-updating config (#5959)
* stages/email: directly use email credentials from config Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use custom database backend that supports dynamic credentials Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add crude config reloader Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make method names for CONFIG clearer Signed-off-by: Jens Langhammer <jens@goauthentik.io> * replace config.set with environ Not sure if this is the cleanest way, but it persists through a config reload Signed-off-by: Jens Langhammer <jens@goauthentik.io> * re-add set for @patch Signed-off-by: Jens Langhammer <jens@goauthentik.io> * even more crudeness Signed-off-by: Jens Langhammer <jens@goauthentik.io> * clean up some old stuff? Signed-off-by: Jens Langhammer <jens@goauthentik.io> * somewhat rewrite config loader to keep track of a source of an attribute so we can refresh it Signed-off-by: Jens Langhammer <jens@goauthentik.io> * cleanup old things Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix flow e2e Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
fb4e4dc8db
commit
2f469d2709
|
@ -58,7 +58,7 @@ def clear_update_notifications():
|
|||
@prefill_task
|
||||
def update_latest_version(self: MonitoredTask):
|
||||
"""Update latest version info"""
|
||||
if CONFIG.y_bool("disable_update_check"):
|
||||
if CONFIG.get_bool("disable_update_check"):
|
||||
cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT)
|
||||
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
|
||||
return
|
||||
|
|
|
@ -70,7 +70,7 @@ class ConfigView(APIView):
|
|||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||
if GEOIP_READER.enabled:
|
||||
caps.append(Capabilities.CAN_GEO_IP)
|
||||
if CONFIG.y_bool("impersonation"):
|
||||
if CONFIG.get_bool("impersonation"):
|
||||
caps.append(Capabilities.CAN_IMPERSONATE)
|
||||
if settings.DEBUG: # pragma: no cover
|
||||
caps.append(Capabilities.CAN_DEBUG)
|
||||
|
@ -86,17 +86,17 @@ class ConfigView(APIView):
|
|||
return ConfigSerializer(
|
||||
{
|
||||
"error_reporting": {
|
||||
"enabled": CONFIG.y("error_reporting.enabled"),
|
||||
"sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"),
|
||||
"environment": CONFIG.y("error_reporting.environment"),
|
||||
"send_pii": CONFIG.y("error_reporting.send_pii"),
|
||||
"traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)),
|
||||
"enabled": CONFIG.get("error_reporting.enabled"),
|
||||
"sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"),
|
||||
"environment": CONFIG.get("error_reporting.environment"),
|
||||
"send_pii": CONFIG.get("error_reporting.send_pii"),
|
||||
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
|
||||
},
|
||||
"capabilities": self.get_capabilities(),
|
||||
"cache_timeout": int(CONFIG.y("redis.cache_timeout")),
|
||||
"cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")),
|
||||
"cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")),
|
||||
"cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")),
|
||||
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
|
||||
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
|
||||
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
|
||||
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path):
|
|||
return
|
||||
blueprint_file.seek(0)
|
||||
instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first()
|
||||
rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir")))
|
||||
rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir")))
|
||||
meta = None
|
||||
if metadata:
|
||||
meta = from_dict(BlueprintMetadata, metadata)
|
||||
|
@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
|||
Flow = apps.get_model("authentik_flows", "Flow")
|
||||
|
||||
db_alias = schema_editor.connection.alias
|
||||
for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True):
|
||||
for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True):
|
||||
check_blueprint_v1_file(BlueprintInstance, Path(file))
|
||||
|
||||
for blueprint in BlueprintInstance.objects.using(db_alias).all():
|
||||
|
|
|
@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
|
|||
def retrieve_file(self) -> str:
|
||||
"""Get blueprint from path"""
|
||||
try:
|
||||
base = Path(CONFIG.y("blueprints_dir"))
|
||||
base = Path(CONFIG.get("blueprints_dir"))
|
||||
full_path = base.joinpath(Path(self.path)).resolve()
|
||||
if not str(full_path).startswith(str(base.resolve())):
|
||||
raise BlueprintRetrievalFailed("Invalid blueprint path")
|
||||
|
|
|
@ -62,7 +62,7 @@ def start_blueprint_watcher():
|
|||
if _file_watcher_started:
|
||||
return
|
||||
observer = Observer()
|
||||
observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True)
|
||||
observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True)
|
||||
observer.start()
|
||||
_file_watcher_started = True
|
||||
|
||||
|
@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
|
|||
blueprints_discovery.delay()
|
||||
if isinstance(event, FileModifiedEvent):
|
||||
path = Path(event.src_path)
|
||||
root = Path(CONFIG.y("blueprints_dir")).absolute()
|
||||
root = Path(CONFIG.get("blueprints_dir")).absolute()
|
||||
rel_path = str(path.relative_to(root))
|
||||
for instance in BlueprintInstance.objects.filter(path=rel_path):
|
||||
LOGGER.debug("modified blueprint file, starting apply", instance=instance)
|
||||
|
@ -101,7 +101,7 @@ def blueprints_find_dict():
|
|||
def blueprints_find():
|
||||
"""Find blueprints and return valid ones"""
|
||||
blueprints = []
|
||||
root = Path(CONFIG.y("blueprints_dir"))
|
||||
root = Path(CONFIG.get("blueprints_dir"))
|
||||
for path in root.rglob("**/*.yaml"):
|
||||
# Check if any part in the path starts with a dot and assume a hidden file
|
||||
if any(part for part in path.parts if part.startswith(".")):
|
||||
|
|
|
@ -596,7 +596,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
|||
@action(detail=True, methods=["POST"])
|
||||
def impersonate(self, request: Request, pk: int) -> Response:
|
||||
"""Impersonate a user"""
|
||||
if not CONFIG.y_bool("impersonation"):
|
||||
if not CONFIG.get_bool("impersonation"):
|
||||
LOGGER.debug("User attempted to impersonate", user=request.user)
|
||||
return Response(status=401)
|
||||
if not request.user.has_perm("impersonate"):
|
||||
|
|
|
@ -18,7 +18,7 @@ class Command(BaseCommand):
|
|||
|
||||
def handle(self, **options):
|
||||
close_old_connections()
|
||||
if CONFIG.y_bool("remote_debug"):
|
||||
if CONFIG.get_bool("remote_debug"):
|
||||
import debugpy
|
||||
|
||||
debugpy.listen(("0.0.0.0", 6900)) # nosec
|
||||
|
|
|
@ -60,7 +60,7 @@ def default_token_key():
|
|||
"""Default token key"""
|
||||
# We use generate_id since the chars in the key should be easy
|
||||
# to use in Emails (for verification) and URLs (for recovery)
|
||||
return generate_id(int(CONFIG.y("default_token_length")))
|
||||
return generate_id(int(CONFIG.get("default_token_length")))
|
||||
|
||||
|
||||
class UserTypes(models.TextChoices):
|
||||
|
|
|
@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask):
|
|||
certs = {}
|
||||
private_keys = {}
|
||||
discovered = 0
|
||||
for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True):
|
||||
for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True):
|
||||
path = Path(file)
|
||||
if not path.exists():
|
||||
continue
|
||||
|
|
|
@ -33,7 +33,7 @@ class GeoIPReader:
|
|||
|
||||
def __open(self):
|
||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||
path = CONFIG.y("geoip")
|
||||
path = CONFIG.get("geoip")
|
||||
if path == "" or not path:
|
||||
return
|
||||
try:
|
||||
|
@ -46,7 +46,7 @@ class GeoIPReader:
|
|||
def __check_expired(self):
|
||||
"""Check if the modification date of the GeoIP database has
|
||||
changed, and reload it if so"""
|
||||
path = CONFIG.y("geoip")
|
||||
path = CONFIG.get("geoip")
|
||||
try:
|
||||
mtime = stat(path).st_mtime
|
||||
diff = self.__last_mtime < mtime
|
||||
|
|
|
@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
|
|||
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
||||
# was restored.
|
||||
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows"))
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
|
||||
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ from authentik.flows.planner import FlowPlan, FlowPlanner
|
|||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
@ -85,7 +84,6 @@ class TestFlowExecutor(FlowTestCase):
|
|||
FlowDesignation.AUTHENTICATION,
|
||||
)
|
||||
|
||||
CONFIG.update_from_dict({"domain": "testserver"})
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
|
@ -111,7 +109,6 @@ class TestFlowExecutor(FlowTestCase):
|
|||
denied_action=FlowDeniedAction.CONTINUE,
|
||||
)
|
||||
|
||||
CONFIG.update_from_dict({"domain": "testserver"})
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
|
@ -128,7 +125,6 @@ class TestFlowExecutor(FlowTestCase):
|
|||
FlowDesignation.AUTHENTICATION,
|
||||
)
|
||||
|
||||
CONFIG.update_from_dict({"domain": "testserver"})
|
||||
dest = "/unique-string"
|
||||
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||
response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}")
|
||||
|
@ -145,7 +141,6 @@ class TestFlowExecutor(FlowTestCase):
|
|||
FlowDesignation.AUTHENTICATION,
|
||||
)
|
||||
|
||||
CONFIG.update_from_dict({"domain": "testserver"})
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
)
|
||||
|
|
|
@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str:
|
|||
"initials": avatar_mode_generated,
|
||||
"gravatar": avatar_mode_gravatar,
|
||||
}
|
||||
modes: str = CONFIG.y("avatars", "none")
|
||||
modes: str = CONFIG.get("avatars", "none")
|
||||
for mode in modes.split(","):
|
||||
avatar = None
|
||||
if mode in mode_map:
|
||||
|
|
|
@ -2,13 +2,15 @@
|
|||
import os
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from glob import glob
|
||||
from json import dumps, loads
|
||||
from json import JSONEncoder, dumps, loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from sys import argv, stderr
|
||||
from time import time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
|
@ -32,15 +34,44 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
|
|||
return root
|
||||
|
||||
|
||||
@dataclass
|
||||
class Attr:
|
||||
"""Single configuration attribute"""
|
||||
|
||||
class Source(Enum):
|
||||
"""Sources a configuration attribute can come from, determines what should be done with
|
||||
Attr.source (and if it's set at all)"""
|
||||
|
||||
UNSPECIFIED = "unspecified"
|
||||
ENV = "env"
|
||||
CONFIG_FILE = "config_file"
|
||||
URI = "uri"
|
||||
|
||||
value: Any
|
||||
|
||||
source_type: Source = field(default=Source.UNSPECIFIED)
|
||||
|
||||
# depending on source_type, might contain the environment variable or the path
|
||||
# to the config file containing this change or the file containing this value
|
||||
source: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
class AttrEncoder(JSONEncoder):
|
||||
"""JSON encoder that can deal with `Attr` classes"""
|
||||
|
||||
def default(self, o: Any) -> Any:
|
||||
if isinstance(o, Attr):
|
||||
return o.value
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with
|
||||
`ENV_PREFIX` are also applied.
|
||||
|
||||
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
||||
|
||||
loaded_file = []
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.__config = {}
|
||||
base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve()
|
||||
|
@ -65,6 +96,7 @@ class ConfigLoader:
|
|||
# Update config with env file
|
||||
self.update_from_file(env_file)
|
||||
self.update_from_env()
|
||||
self.update(self.__config, kwargs)
|
||||
|
||||
def log(self, level: str, message: str, **kwargs):
|
||||
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
||||
|
@ -86,22 +118,32 @@ class ConfigLoader:
|
|||
else:
|
||||
if isinstance(value, str):
|
||||
value = self.parse_uri(value)
|
||||
elif not isinstance(value, Attr):
|
||||
value = Attr(value)
|
||||
root[key] = value
|
||||
return root
|
||||
|
||||
def parse_uri(self, value: str) -> str:
|
||||
def refresh(self, key: str):
|
||||
"""Update a single value"""
|
||||
attr: Attr = get_path_from_dict(self.raw, key)
|
||||
if attr.source_type != Attr.Source.URI:
|
||||
return
|
||||
attr.value = self.parse_uri(attr.source).value
|
||||
|
||||
def parse_uri(self, value: str) -> Attr:
|
||||
"""Parse string values which start with a URI"""
|
||||
url = urlparse(value)
|
||||
parsed_value = value
|
||||
if url.scheme == "env":
|
||||
value = os.getenv(url.netloc, url.query)
|
||||
parsed_value = os.getenv(url.netloc, url.query)
|
||||
if url.scheme == "file":
|
||||
try:
|
||||
with open(url.path, "r", encoding="utf8") as _file:
|
||||
value = _file.read().strip()
|
||||
parsed_value = _file.read().strip()
|
||||
except OSError as exc:
|
||||
self.log("error", f"Failed to read config value from {url.path}: {exc}")
|
||||
value = url.query
|
||||
return value
|
||||
parsed_value = url.query
|
||||
return Attr(parsed_value, Attr.Source.URI, value)
|
||||
|
||||
def update_from_file(self, path: Path):
|
||||
"""Update config from file contents"""
|
||||
|
@ -110,7 +152,6 @@ class ConfigLoader:
|
|||
try:
|
||||
self.update(self.__config, yaml.safe_load(file))
|
||||
self.log("debug", "Loaded config", file=str(path))
|
||||
self.loaded_file.append(path)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ImproperlyConfigured from exc
|
||||
except PermissionError as exc:
|
||||
|
@ -121,10 +162,6 @@ class ConfigLoader:
|
|||
error=str(exc),
|
||||
)
|
||||
|
||||
def update_from_dict(self, update: dict):
|
||||
"""Update config from dict"""
|
||||
self.__config.update(update)
|
||||
|
||||
def update_from_env(self):
|
||||
"""Check environment variables"""
|
||||
outer = {}
|
||||
|
@ -145,7 +182,7 @@ class ConfigLoader:
|
|||
value = loads(value)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
current_obj[dot_parts[-1]] = value
|
||||
current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
|
||||
idx += 1
|
||||
if idx > 0:
|
||||
self.log("debug", "Loaded environment variables", count=idx)
|
||||
|
@ -154,28 +191,32 @@ class ConfigLoader:
|
|||
@contextmanager
|
||||
def patch(self, path: str, value: Any):
|
||||
"""Context manager for unittests to patch a value"""
|
||||
original_value = self.y(path)
|
||||
self.y_set(path, value)
|
||||
original_value = self.get(path)
|
||||
self.set(path, value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.y_set(path, original_value)
|
||||
self.set(path, original_value)
|
||||
|
||||
@property
|
||||
def raw(self) -> dict:
|
||||
"""Get raw config dictionary"""
|
||||
return self.__config
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def y(self, path: str, default=None, sep=".") -> Any:
|
||||
def get(self, path: str, default=None, sep=".") -> Any:
|
||||
"""Access attribute by using yaml path"""
|
||||
# Walk sub_dicts before parsing path
|
||||
root = self.raw
|
||||
# Walk each component of the path
|
||||
return get_path_from_dict(root, path, sep=sep, default=default)
|
||||
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
|
||||
return attr.value
|
||||
|
||||
def y_set(self, path: str, value: Any, sep="."):
|
||||
"""Set value using same syntax as y()"""
|
||||
def get_bool(self, path: str, default=False) -> bool:
|
||||
"""Wrapper for get that converts value into boolean"""
|
||||
return str(self.get(path, default)).lower() == "true"
|
||||
|
||||
def set(self, path: str, value: Any, sep="."):
|
||||
"""Set value using same syntax as get()"""
|
||||
# Walk sub_dicts before parsing path
|
||||
root = self.raw
|
||||
# Walk each component of the path
|
||||
|
@ -184,17 +225,14 @@ class ConfigLoader:
|
|||
if comp not in root:
|
||||
root[comp] = {}
|
||||
root = root.get(comp, {})
|
||||
root[path_parts[-1]] = value
|
||||
|
||||
def y_bool(self, path: str, default=False) -> bool:
|
||||
"""Wrapper for y that converts value into boolean"""
|
||||
return str(self.y(path, default)).lower() == "true"
|
||||
root[path_parts[-1]] = Attr(value)
|
||||
|
||||
|
||||
CONFIG = ConfigLoader()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(argv) < 2:
|
||||
print(dumps(CONFIG.raw, indent=4))
|
||||
print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
|
||||
else:
|
||||
print(CONFIG.y(argv[1]))
|
||||
print(CONFIG.get(argv[1]))
|
||||
|
|
|
@ -51,18 +51,18 @@ class SentryTransport(HttpTransport):
|
|||
|
||||
def sentry_init(**sentry_init_kwargs):
|
||||
"""Configure sentry SDK"""
|
||||
sentry_env = CONFIG.y("error_reporting.environment", "customer")
|
||||
sentry_env = CONFIG.get("error_reporting.environment", "customer")
|
||||
kwargs = {
|
||||
"environment": sentry_env,
|
||||
"send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False),
|
||||
"send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False),
|
||||
"_experiments": {
|
||||
"profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)),
|
||||
"profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)),
|
||||
},
|
||||
}
|
||||
kwargs.update(**sentry_init_kwargs)
|
||||
# pylint: disable=abstract-class-instantiated
|
||||
sentry_sdk_init(
|
||||
dsn=CONFIG.y("error_reporting.sentry_dsn"),
|
||||
dsn=CONFIG.get("error_reporting.sentry_dsn"),
|
||||
integrations=[
|
||||
ArgvIntegration(),
|
||||
StdlibIntegration(),
|
||||
|
@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float:
|
|||
return 0
|
||||
if _type == "websocket":
|
||||
return 0
|
||||
return float(CONFIG.y("error_reporting.sample_rate", 0.1))
|
||||
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
|
||||
|
||||
|
||||
def before_send(event: dict, hint: dict) -> Optional[dict]:
|
||||
|
|
|
@ -16,23 +16,23 @@ class TestConfig(TestCase):
|
|||
config = ConfigLoader()
|
||||
environ[ENV_PREFIX + "_test__test"] = "bar"
|
||||
config.update_from_env()
|
||||
self.assertEqual(config.y("test.test"), "bar")
|
||||
self.assertEqual(config.get("test.test"), "bar")
|
||||
|
||||
def test_patch(self):
|
||||
"""Test patch decorator"""
|
||||
config = ConfigLoader()
|
||||
config.y_set("foo.bar", "bar")
|
||||
self.assertEqual(config.y("foo.bar"), "bar")
|
||||
config.set("foo.bar", "bar")
|
||||
self.assertEqual(config.get("foo.bar"), "bar")
|
||||
with config.patch("foo.bar", "baz"):
|
||||
self.assertEqual(config.y("foo.bar"), "baz")
|
||||
self.assertEqual(config.y("foo.bar"), "bar")
|
||||
self.assertEqual(config.get("foo.bar"), "baz")
|
||||
self.assertEqual(config.get("foo.bar"), "bar")
|
||||
|
||||
def test_uri_env(self):
|
||||
"""Test URI parsing (environment)"""
|
||||
config = ConfigLoader()
|
||||
environ["foo"] = "bar"
|
||||
self.assertEqual(config.parse_uri("env://foo"), "bar")
|
||||
self.assertEqual(config.parse_uri("env://foo?bar"), "bar")
|
||||
self.assertEqual(config.parse_uri("env://foo").value, "bar")
|
||||
self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
|
||||
|
||||
def test_uri_file(self):
|
||||
"""Test URI parsing (file load)"""
|
||||
|
@ -41,11 +41,25 @@ class TestConfig(TestCase):
|
|||
write(file, "foo".encode())
|
||||
_, file2_name = mkstemp()
|
||||
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
|
||||
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo")
|
||||
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def")
|
||||
self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo")
|
||||
self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def")
|
||||
unlink(file_name)
|
||||
unlink(file2_name)
|
||||
|
||||
def test_uri_file_update(self):
|
||||
"""Test URI parsing (file load and update)"""
|
||||
file, file_name = mkstemp()
|
||||
write(file, "foo".encode())
|
||||
config = ConfigLoader(file_test=f"file://{file_name}")
|
||||
self.assertEqual(config.get("file_test"), "foo")
|
||||
|
||||
# Update config file
|
||||
write(file, "bar".encode())
|
||||
config.refresh("file_test")
|
||||
self.assertEqual(config.get("file_test"), "foobar")
|
||||
|
||||
unlink(file_name)
|
||||
|
||||
def test_file_update(self):
|
||||
"""Test update_from_file"""
|
||||
config = ConfigLoader()
|
||||
|
|
|
@ -50,7 +50,7 @@ def get_env() -> str:
|
|||
"""Get environment in which authentik is currently running"""
|
||||
if "CI" in os.environ:
|
||||
return "ci"
|
||||
if CONFIG.y_bool("debug"):
|
||||
if CONFIG.get_bool("debug"):
|
||||
return "dev"
|
||||
if SERVICE_HOST_ENV_NAME in os.environ:
|
||||
return "kubernetes"
|
||||
|
|
|
@ -97,7 +97,7 @@ class BaseController:
|
|||
if self.outpost.config.container_image is not None:
|
||||
return self.outpost.config.container_image
|
||||
|
||||
image_name_template: str = CONFIG.y("outposts.container_image_base")
|
||||
image_name_template: str = CONFIG.get("outposts.container_image_base")
|
||||
return image_name_template % {
|
||||
"type": self.outpost.type,
|
||||
"version": __version__,
|
||||
|
|
|
@ -58,7 +58,7 @@ class OutpostConfig:
|
|||
authentik_host_insecure: bool = False
|
||||
authentik_host_browser: str = ""
|
||||
|
||||
log_level: str = CONFIG.y("log_level")
|
||||
log_level: str = CONFIG.get("log_level")
|
||||
object_naming_template: str = field(default="ak-outpost-%(name)s")
|
||||
|
||||
container_image: Optional[str] = field(default=None)
|
||||
|
|
|
@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
|||
def outpost_connection_discovery(self: MonitoredTask):
|
||||
"""Checks the local environment and create Service connections."""
|
||||
status = TaskResult(TaskResultStatus.SUCCESSFUL)
|
||||
if not CONFIG.y_bool("outposts.discover"):
|
||||
if not CONFIG.get_bool("outposts.discover"):
|
||||
status.messages.append("Outpost integration discovery is disabled")
|
||||
self.set_status(status)
|
||||
return
|
||||
|
|
|
@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
|||
LOGGER = get_logger()
|
||||
|
||||
FORK_CTX = get_context("fork")
|
||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies"))
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
|
||||
PROCESS_CLASS = FORK_CTX.Process
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
|
|||
from authentik.stages.identification.signals import identification_failed
|
||||
|
||||
LOGGER = get_logger()
|
||||
CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation"))
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
|
||||
|
||||
|
||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
|
|
|
@ -46,7 +46,7 @@ class DeviceView(View):
|
|||
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
throttle = AnonRateThrottle()
|
||||
throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour")
|
||||
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
||||
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
||||
if not throttle.allow_request(request, self):
|
||||
return HttpResponse(status=429)
|
||||
|
|
0
authentik/root/db/__init__.py
Normal file
0
authentik/root/db/__init__.py
Normal file
15
authentik/root/db/base.py
Normal file
15
authentik/root/db/base.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
"""authentik database backend"""
|
||||
from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
"""database backend which supports rotating credentials"""
|
||||
|
||||
def get_connection_params(self):
|
||||
CONFIG.refresh("postgresql.password")
|
||||
conn_params = super().get_connection_params()
|
||||
conn_params["user"] = CONFIG.get("postgresql.user")
|
||||
conn_params["password"] = CONFIG.get("postgresql.password")
|
||||
return conn_params
|
|
@ -26,15 +26,15 @@ def get_install_id_raw():
|
|||
"""Get install_id without django loaded, this is required for the startup when we get
|
||||
the install_id but django isn't loaded yet and we can't use the function above."""
|
||||
conn = connect(
|
||||
dbname=CONFIG.y("postgresql.name"),
|
||||
user=CONFIG.y("postgresql.user"),
|
||||
password=CONFIG.y("postgresql.password"),
|
||||
host=CONFIG.y("postgresql.host"),
|
||||
port=int(CONFIG.y("postgresql.port")),
|
||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
||||
dbname=CONFIG.get("postgresql.name"),
|
||||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;")
|
||||
|
|
|
@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent
|
|||
STATICFILES_DIRS = [BASE_DIR / Path("web")]
|
||||
MEDIA_ROOT = BASE_DIR / Path("media")
|
||||
|
||||
DEBUG = CONFIG.y_bool("debug")
|
||||
SECRET_KEY = CONFIG.y("secret_key")
|
||||
DEBUG = CONFIG.get_bool("debug")
|
||||
SECRET_KEY = CONFIG.get("secret_key")
|
||||
|
||||
INTERNAL_IPS = ["127.0.0.1"]
|
||||
ALLOWED_HOSTS = ["*"]
|
||||
|
@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf"
|
|||
CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF"
|
||||
LANGUAGE_COOKIE_NAME = "authentik_language"
|
||||
SESSION_COOKIE_NAME = "authentik_session"
|
||||
SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None)
|
||||
SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None)
|
||||
|
||||
AUTHENTICATION_BACKENDS = [
|
||||
"django.contrib.auth.backends.ModelBackend",
|
||||
|
@ -179,26 +179,26 @@ REST_FRAMEWORK = {
|
|||
"TEST_REQUEST_DEFAULT_FORMAT": "json",
|
||||
"DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"],
|
||||
"DEFAULT_THROTTLE_RATES": {
|
||||
"anon": CONFIG.y("throttle.default"),
|
||||
"anon": CONFIG.get("throttle.default"),
|
||||
},
|
||||
}
|
||||
|
||||
_redis_protocol_prefix = "redis://"
|
||||
_redis_celery_tls_requirements = ""
|
||||
if CONFIG.y_bool("redis.tls", False):
|
||||
if CONFIG.get_bool("redis.tls", False):
|
||||
_redis_protocol_prefix = "rediss://"
|
||||
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}"
|
||||
_redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}"
|
||||
_redis_url = (
|
||||
f"{_redis_protocol_prefix}:"
|
||||
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
|
||||
f"{int(CONFIG.y('redis.port'))}"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{int(CONFIG.get('redis.port'))}"
|
||||
)
|
||||
|
||||
CACHES = {
|
||||
"default": {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}",
|
||||
"TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)),
|
||||
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
|
||||
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
|
||||
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
||||
"KEY_PREFIX": "authentik_cache",
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ ROOT_URLCONF = "authentik.root.urls"
|
|||
TEMPLATES = [
|
||||
{
|
||||
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||
"DIRS": [CONFIG.y("email.template_dir")],
|
||||
"DIRS": [CONFIG.get("email.template_dir")],
|
||||
"APP_DIRS": True,
|
||||
"OPTIONS": {
|
||||
"context_processors": [
|
||||
|
@ -258,7 +258,7 @@ CHANNEL_LAYERS = {
|
|||
"default": {
|
||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
||||
"CONFIG": {
|
||||
"hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"],
|
||||
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
|
||||
"prefix": "authentik_channels",
|
||||
},
|
||||
},
|
||||
|
@ -270,34 +270,37 @@ CHANNEL_LAYERS = {
|
|||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django_prometheus.db.backends.postgresql",
|
||||
"HOST": CONFIG.y("postgresql.host"),
|
||||
"NAME": CONFIG.y("postgresql.name"),
|
||||
"USER": CONFIG.y("postgresql.user"),
|
||||
"PASSWORD": CONFIG.y("postgresql.password"),
|
||||
"PORT": int(CONFIG.y("postgresql.port")),
|
||||
"SSLMODE": CONFIG.y("postgresql.sslmode"),
|
||||
"SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"),
|
||||
"SSLCERT": CONFIG.y("postgresql.sslcert"),
|
||||
"SSLKEY": CONFIG.y("postgresql.sslkey"),
|
||||
"ENGINE": "authentik.root.db",
|
||||
"HOST": CONFIG.get("postgresql.host"),
|
||||
"NAME": CONFIG.get("postgresql.name"),
|
||||
"USER": CONFIG.get("postgresql.user"),
|
||||
"PASSWORD": CONFIG.get("postgresql.password"),
|
||||
"PORT": int(CONFIG.get("postgresql.port")),
|
||||
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
||||
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
||||
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
||||
"SSLKEY": CONFIG.get("postgresql.sslkey"),
|
||||
}
|
||||
}
|
||||
|
||||
if CONFIG.y_bool("postgresql.use_pgbouncer", False):
|
||||
if CONFIG.get_bool("postgresql.use_pgbouncer", False):
|
||||
# https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors
|
||||
DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True
|
||||
# https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections
|
||||
DATABASES["default"]["CONN_MAX_AGE"] = None # persistent
|
||||
|
||||
# Email
|
||||
EMAIL_HOST = CONFIG.y("email.host")
|
||||
EMAIL_PORT = int(CONFIG.y("email.port"))
|
||||
EMAIL_HOST_USER = CONFIG.y("email.username")
|
||||
EMAIL_HOST_PASSWORD = CONFIG.y("email.password")
|
||||
EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False)
|
||||
EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False)
|
||||
EMAIL_TIMEOUT = int(CONFIG.y("email.timeout"))
|
||||
DEFAULT_FROM_EMAIL = CONFIG.y("email.from")
|
||||
# These values should never actually be used, emails are only sent from email stages, which
|
||||
# loads the config directly from CONFIG
|
||||
# See authentik/stages/email/models.py, line 105
|
||||
EMAIL_HOST = CONFIG.get("email.host")
|
||||
EMAIL_PORT = int(CONFIG.get("email.port"))
|
||||
EMAIL_HOST_USER = CONFIG.get("email.username")
|
||||
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
|
||||
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
|
||||
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
|
||||
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
|
||||
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
|
||||
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
||||
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
||||
|
||||
|
@ -345,15 +348,15 @@ CELERY = {
|
|||
},
|
||||
"task_create_missing_queues": True,
|
||||
"task_default_queue": "authentik",
|
||||
"broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
|
||||
"result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}",
|
||||
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
|
||||
}
|
||||
|
||||
# Sentry integration
|
||||
env = get_env()
|
||||
_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False)
|
||||
_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
|
||||
if _ERROR_REPORTING:
|
||||
sentry_env = CONFIG.y("error_reporting.environment", "customer")
|
||||
sentry_env = CONFIG.get("error_reporting.environment", "customer")
|
||||
sentry_init()
|
||||
set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16])
|
||||
|
||||
|
@ -367,7 +370,7 @@ MEDIA_URL = "/media/"
|
|||
TEST = False
|
||||
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
|
||||
# We can't check TEST here as its set later by the test runner
|
||||
LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
|
||||
LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG"
|
||||
# We could add a custom level to stdlib logging and structlog, but it's not easy or clean
|
||||
# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
|
||||
# Additionally, the entire code uses debug as highest level so that would have to be re-written too
|
||||
|
|
|
@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover
|
|||
|
||||
settings.TEST = True
|
||||
settings.CELERY["task_always_eager"] = True
|
||||
CONFIG.y_set("avatars", "none")
|
||||
CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb")
|
||||
CONFIG.y_set("blueprints_dir", "./blueprints")
|
||||
CONFIG.y_set(
|
||||
CONFIG.set("avatars", "none")
|
||||
CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb")
|
||||
CONFIG.set("blueprints_dir", "./blueprints")
|
||||
CONFIG.set(
|
||||
"outposts.container_image_base",
|
||||
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
|
||||
)
|
||||
CONFIG.y_set("error_reporting.sample_rate", 0)
|
||||
CONFIG.set("error_reporting.sample_rate", 0)
|
||||
sentry_init(
|
||||
environment="testing",
|
||||
send_default_pii=True,
|
||||
|
|
|
@ -136,7 +136,7 @@ class LDAPSource(Source):
|
|||
chmod(private_key_file, 0o600)
|
||||
tls_kwargs["local_private_key_file"] = private_key_file
|
||||
tls_kwargs["local_certificate_file"] = certificate_file
|
||||
if ciphers := CONFIG.y("ldap.tls.ciphers", None):
|
||||
if ciphers := CONFIG.get("ldap.tls.ciphers", None):
|
||||
tls_kwargs["ciphers"] = ciphers.strip()
|
||||
if self.sni:
|
||||
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()
|
||||
|
|
|
@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
|
|||
types_only=False,
|
||||
get_operational_attributes=False,
|
||||
controls=None,
|
||||
paged_size=int(CONFIG.y("ldap.page_size", 50)),
|
||||
paged_size=int(CONFIG.get("ldap.page_size", 50)),
|
||||
paged_criticality=False,
|
||||
):
|
||||
"""Search in pages, returns each page"""
|
||||
|
|
|
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
|||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
|
||||
task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")),
|
||||
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||
)
|
||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||
"""Synchronization of an LDAP Source"""
|
||||
self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours"))
|
||||
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
|
||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||
if not source:
|
||||
# Because the source couldn't be found, we don't have a UID
|
||||
|
|
|
@ -13,6 +13,7 @@ from rest_framework.serializers import BaseSerializer
|
|||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.flows.models import Stage
|
||||
from authentik.lib.config import CONFIG
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
@ -104,7 +105,16 @@ class EmailStage(Stage):
|
|||
def backend(self) -> BaseEmailBackend:
|
||||
"""Get fully configured Email Backend instance"""
|
||||
if self.use_global_settings:
|
||||
return self.backend_class()
|
||||
CONFIG.refresh("email.password")
|
||||
return self.backend_class(
|
||||
host=CONFIG.get("email.host"),
|
||||
port=int(CONFIG.get("email.port")),
|
||||
username=CONFIG.get("email.username"),
|
||||
password=CONFIG.get("email.password"),
|
||||
use_tls=CONFIG.get_bool("email.use_tls", False),
|
||||
use_ssl=CONFIG.get_bool("email.use_ssl", False),
|
||||
timeout=int(CONFIG.get("email.timeout")),
|
||||
)
|
||||
return self.backend_class(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
|
|
|
@ -13,6 +13,7 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken
|
|||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.tests import FlowTestCase
|
||||
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.stages.email.models import EmailStage
|
||||
from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE
|
||||
|
||||
|
@ -120,7 +121,7 @@ class TestEmailStage(FlowTestCase):
|
|||
def test_use_global_settings(self):
|
||||
"""Test use_global_settings"""
|
||||
host = "some-unique-string"
|
||||
with self.settings(EMAIL_HOST=host):
|
||||
with CONFIG.patch("email.host", host):
|
||||
self.assertEqual(EmailStage(use_global_settings=True).backend.host, host)
|
||||
|
||||
def test_token(self):
|
||||
|
|
|
@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer):
|
|||
ui_footer_links = ListField(
|
||||
child=FooterLinkSerializer(),
|
||||
read_only=True,
|
||||
default=CONFIG.y("footer_links", []),
|
||||
default=CONFIG.get("footer_links", []),
|
||||
)
|
||||
ui_theme = ChoiceField(
|
||||
choices=Themes.choices,
|
||||
|
|
|
@ -24,7 +24,7 @@ class TestTenants(APITestCase):
|
|||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "authentik",
|
||||
"matched_domain": tenant.domain,
|
||||
"ui_footer_links": CONFIG.y("footer_links"),
|
||||
"ui_footer_links": CONFIG.get("footer_links"),
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
},
|
||||
|
@ -43,7 +43,7 @@ class TestTenants(APITestCase):
|
|||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom",
|
||||
"matched_domain": "bar.baz",
|
||||
"ui_footer_links": CONFIG.y("footer_links"),
|
||||
"ui_footer_links": CONFIG.get("footer_links"),
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
},
|
||||
|
@ -59,7 +59,7 @@ class TestTenants(APITestCase):
|
|||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "authentik",
|
||||
"matched_domain": "fallback",
|
||||
"ui_footer_links": CONFIG.y("footer_links"),
|
||||
"ui_footer_links": CONFIG.get("footer_links"),
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
},
|
||||
|
|
|
@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
|
|||
trace = span.to_traceparent()
|
||||
return {
|
||||
"tenant": tenant,
|
||||
"footer_links": CONFIG.y("footer_links"),
|
||||
"footer_links": CONFIG.get("footer_links"),
|
||||
"sentry_trace": trace,
|
||||
"version": get_full_version(),
|
||||
}
|
||||
|
|
|
@ -94,21 +94,21 @@ entries:
|
|||
prompt_data = request.context.get("prompt_data")
|
||||
|
||||
if not request.user.group_attributes(request.http_request).get(
|
||||
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True)
|
||||
USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True)
|
||||
):
|
||||
if prompt_data.get("email") != request.user.email:
|
||||
ak_message("Not allowed to change email address.")
|
||||
return False
|
||||
|
||||
if not request.user.group_attributes(request.http_request).get(
|
||||
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True)
|
||||
USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True)
|
||||
):
|
||||
if prompt_data.get("name") != request.user.name:
|
||||
ak_message("Not allowed to change name.")
|
||||
return False
|
||||
|
||||
if not request.user.group_attributes(request.http_request).get(
|
||||
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True)
|
||||
USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True)
|
||||
):
|
||||
if prompt_data.get("username") != request.user.username:
|
||||
ak_message("Not allowed to change username.")
|
||||
|
|
|
@ -37,7 +37,7 @@ makedirs(prometheus_tmp_dir, exist_ok=True)
|
|||
max_requests = 1000
|
||||
max_requests_jitter = 50
|
||||
|
||||
_debug = CONFIG.y_bool("DEBUG", False)
|
||||
_debug = CONFIG.get_bool("DEBUG", False)
|
||||
|
||||
logconfig_dict = {
|
||||
"version": 1,
|
||||
|
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
|
|||
else:
|
||||
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
||||
|
||||
workers = int(CONFIG.y("web.workers", default_workers))
|
||||
threads = int(CONFIG.y("web.threads", 4))
|
||||
workers = int(CONFIG.get("web.workers", default_workers))
|
||||
threads = int(CONFIG.get("web.threads", 4))
|
||||
|
||||
|
||||
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||
|
@ -133,7 +133,7 @@ def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
|||
worker._worker_id = _next_worker_id(server)
|
||||
|
||||
|
||||
if not CONFIG.y_bool("disable_startup_analytics", False):
|
||||
if not CONFIG.get_bool("disable_startup_analytics", False):
|
||||
env = get_env()
|
||||
should_send = env not in ["dev", "ci"]
|
||||
if should_send:
|
||||
|
@ -158,7 +158,7 @@ if not CONFIG.y_bool("disable_startup_analytics", False):
|
|||
except Exception: # nosec
|
||||
pass
|
||||
|
||||
if CONFIG.y_bool("remote_debug"):
|
||||
if CONFIG.get_bool("remote_debug"):
|
||||
import debugpy
|
||||
|
||||
debugpy.listen(("0.0.0.0", 6800)) # nosec
|
||||
|
|
|
@ -52,15 +52,15 @@ def release_lock():
|
|||
|
||||
if __name__ == "__main__":
|
||||
conn = connect(
|
||||
dbname=CONFIG.y("postgresql.name"),
|
||||
user=CONFIG.y("postgresql.user"),
|
||||
password=CONFIG.y("postgresql.password"),
|
||||
host=CONFIG.y("postgresql.host"),
|
||||
port=int(CONFIG.y("postgresql.port")),
|
||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
||||
dbname=CONFIG.get("postgresql.name"),
|
||||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||
)
|
||||
curr = conn.cursor()
|
||||
try:
|
||||
|
|
|
@ -25,7 +25,7 @@ class Migration(BaseMigration):
|
|||
# If we already have migrations in the database, assume we're upgrading an existing install
|
||||
# and set the install id to the secret key
|
||||
self.cur.execute(
|
||||
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.y("secret_key"),)
|
||||
"INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.get("secret_key"),)
|
||||
)
|
||||
else:
|
||||
# Otherwise assume a new install, generate an install ID based on a UUID
|
||||
|
|
|
@ -108,14 +108,14 @@ class Migration(BaseMigration):
|
|||
self.con.commit()
|
||||
# We also need to clean the cache to make sure no pickeled objects still exist
|
||||
for db in [
|
||||
CONFIG.y("redis.message_queue_db"),
|
||||
CONFIG.y("redis.cache_db"),
|
||||
CONFIG.y("redis.ws_db"),
|
||||
CONFIG.get("redis.message_queue_db"),
|
||||
CONFIG.get("redis.cache_db"),
|
||||
CONFIG.get("redis.ws_db"),
|
||||
]:
|
||||
redis = Redis(
|
||||
host=CONFIG.y("redis.host"),
|
||||
host=CONFIG.get("redis.host"),
|
||||
port=6379,
|
||||
db=db,
|
||||
password=CONFIG.y("redis.password"),
|
||||
password=CONFIG.get("redis.password"),
|
||||
)
|
||||
redis.flushall()
|
||||
|
|
|
@ -14,7 +14,7 @@ from authentik.lib.config import CONFIG
|
|||
CONFIG.log("info", "Starting authentik bootstrap")
|
||||
|
||||
# Sanity check, ensure SECRET_KEY is set before we even check for database connectivity
|
||||
if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
|
||||
if CONFIG.get("secret_key") is None or len(CONFIG.get("secret_key")) == 0:
|
||||
CONFIG.log("info", "----------------------------------------------------------------------")
|
||||
CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.")
|
||||
CONFIG.log("info", "----------------------------------------------------------------------")
|
||||
|
@ -24,15 +24,15 @@ if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0:
|
|||
while True:
|
||||
try:
|
||||
conn = connect(
|
||||
dbname=CONFIG.y("postgresql.name"),
|
||||
user=CONFIG.y("postgresql.user"),
|
||||
password=CONFIG.y("postgresql.password"),
|
||||
host=CONFIG.y("postgresql.host"),
|
||||
port=int(CONFIG.y("postgresql.port")),
|
||||
sslmode=CONFIG.y("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.y("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.y("postgresql.sslcert"),
|
||||
sslkey=CONFIG.y("postgresql.sslkey"),
|
||||
dbname=CONFIG.get("postgresql.name"),
|
||||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
sslkey=CONFIG.get("postgresql.sslkey"),
|
||||
)
|
||||
conn.cursor()
|
||||
break
|
||||
|
@ -42,12 +42,12 @@ while True:
|
|||
CONFIG.log("info", "PostgreSQL connection successful")
|
||||
|
||||
REDIS_PROTOCOL_PREFIX = "redis://"
|
||||
if CONFIG.y_bool("redis.tls", False):
|
||||
if CONFIG.get_bool("redis.tls", False):
|
||||
REDIS_PROTOCOL_PREFIX = "rediss://"
|
||||
REDIS_URL = (
|
||||
f"{REDIS_PROTOCOL_PREFIX}:"
|
||||
f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:"
|
||||
f"{int(CONFIG.y('redis.port'))}/{CONFIG.y('redis.db')}"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Test Enroll flow"""
|
||||
from time import sleep
|
||||
|
||||
from django.test import override_settings
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as ec
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
|
@ -9,6 +8,7 @@ from selenium.webdriver.support.wait import WebDriverWait
|
|||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import User
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.stages.identification.models import IdentificationStage
|
||||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
@ -56,7 +56,7 @@ class TestFlowsEnroll(SeleniumTestCase):
|
|||
@apply_blueprint(
|
||||
"example/flows-enrollment-email-verification.yaml",
|
||||
)
|
||||
@override_settings(EMAIL_PORT=1025)
|
||||
@CONFIG.patch("email.port", 1025)
|
||||
def test_enroll_email(self):
|
||||
"""Test enroll with Email verification"""
|
||||
# Attach enrollment flow to identification stage
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Test recovery flow"""
|
||||
from time import sleep
|
||||
|
||||
from django.test import override_settings
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as ec
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
|
@ -10,6 +9,7 @@ from authentik.blueprints.tests import apply_blueprint
|
|||
from authentik.core.models import User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.stages.identification.models import IdentificationStage
|
||||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
@ -47,7 +47,7 @@ class TestFlowsRecovery(SeleniumTestCase):
|
|||
@apply_blueprint(
|
||||
"example/flows-recovery-email-verification.yaml",
|
||||
)
|
||||
@override_settings(EMAIL_PORT=1025)
|
||||
@CONFIG.patch("email.port", 1025)
|
||||
def test_recover_email(self):
|
||||
"""Test recovery with Email verification"""
|
||||
# Attach recovery flow to identification stage
|
||||
|
|
Reference in a new issue