events: rewrite GeoIP to a wrapper, reload file every 8 hours
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
f5dbdbd48b
commit
17326615b7
|
@ -43,7 +43,7 @@ class ConfigView(APIView):
|
||||||
deb_test = settings.DEBUG or settings.TEST
|
deb_test = settings.DEBUG or settings.TEST
|
||||||
if path.ismount(settings.MEDIA_ROOT) or deb_test:
|
if path.ismount(settings.MEDIA_ROOT) or deb_test:
|
||||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||||
if GEOIP_READER:
|
if GEOIP_READER.enabled:
|
||||||
caps.append(Capabilities.CAN_GEO_IP)
|
caps.append(Capabilities.CAN_GEO_IP)
|
||||||
return caps
|
return caps
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
from typing import Optional, TypedDict
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
from django_filters.rest_framework import DjangoFilterBackend
|
from django_filters.rest_framework import DjangoFilterBackend
|
||||||
from geoip2.errors import GeoIP2Error
|
|
||||||
from guardian.utils import get_anonymous_user
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework.fields import SerializerMethodField
|
from rest_framework.fields import SerializerMethodField
|
||||||
|
@ -13,7 +12,7 @@ from rest_framework.viewsets import GenericViewSet
|
||||||
from ua_parser import user_agent_parser
|
from ua_parser import user_agent_parser
|
||||||
|
|
||||||
from authentik.core.models import AuthenticatedSession
|
from authentik.core.models import AuthenticatedSession
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.geo import GEOIP_READER, GeoIPDict
|
||||||
|
|
||||||
|
|
||||||
class UserAgentDeviceDict(TypedDict):
|
class UserAgentDeviceDict(TypedDict):
|
||||||
|
@ -52,15 +51,6 @@ class UserAgentDict(TypedDict):
|
||||||
string: str
|
string: str
|
||||||
|
|
||||||
|
|
||||||
class GeoIPDict(TypedDict):
|
|
||||||
"""GeoIP Details"""
|
|
||||||
|
|
||||||
continent: str
|
|
||||||
country: str
|
|
||||||
lat: float
|
|
||||||
long: float
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticatedSessionSerializer(ModelSerializer):
|
class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
"""AuthenticatedSession Serializer"""
|
"""AuthenticatedSession Serializer"""
|
||||||
|
|
||||||
|
@ -81,18 +71,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
self, instance: AuthenticatedSession
|
self, instance: AuthenticatedSession
|
||||||
) -> Optional[GeoIPDict]: # pragma: no cover
|
) -> Optional[GeoIPDict]: # pragma: no cover
|
||||||
"""Get parsed user agent"""
|
"""Get parsed user agent"""
|
||||||
if not GEOIP_READER:
|
return GEOIP_READER.city_dict(instance.last_ip)
|
||||||
return None
|
|
||||||
try:
|
|
||||||
city = GEOIP_READER.city(instance.last_ip)
|
|
||||||
return {
|
|
||||||
"continent": city.continent.code,
|
|
||||||
"country": city.country.iso_code,
|
|
||||||
"lat": city.location.latitude,
|
|
||||||
"long": city.location.longitude,
|
|
||||||
}
|
|
||||||
except (GeoIP2Error, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
"""events GeoIP Reader"""
|
"""events GeoIP Reader"""
|
||||||
from typing import Optional
|
from datetime import datetime
|
||||||
|
from os import stat
|
||||||
|
from time import time
|
||||||
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
from geoip2.database import Reader
|
from geoip2.database import Reader
|
||||||
|
from geoip2.errors import GeoIP2Error
|
||||||
|
from geoip2.models import City
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
|
@ -9,17 +14,78 @@ from authentik.lib.config import CONFIG
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def get_geoip_reader() -> Optional[Reader]:
|
class GeoIPDict(TypedDict):
|
||||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
"""GeoIP Details"""
|
||||||
path = CONFIG.y("authentik.geoip")
|
|
||||||
if path == "" or not path:
|
continent: str
|
||||||
return None
|
country: str
|
||||||
try:
|
lat: float
|
||||||
reader = Reader(path)
|
long: float
|
||||||
LOGGER.info("Enabled GeoIP support")
|
city: str
|
||||||
return reader
|
|
||||||
except OSError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
GEOIP_READER = get_geoip_reader()
|
class GeoIPReader:
|
||||||
|
"""Slim wrapper around GeoIP API"""
|
||||||
|
|
||||||
|
__reader: Optional[Reader] = None
|
||||||
|
__last_mtime: float = 0.0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.__open()
|
||||||
|
|
||||||
|
def __open(self):
|
||||||
|
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||||
|
path = CONFIG.y("authentik.geoip")
|
||||||
|
if path == "" or not path:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
reader = Reader(path)
|
||||||
|
LOGGER.info("Loaded GeoIP database")
|
||||||
|
self.__reader = reader
|
||||||
|
self.__last_mtime = stat(path).st_mtime
|
||||||
|
except OSError as exc:
|
||||||
|
LOGGER.warning("Failed to load GeoIP database", exc=exc)
|
||||||
|
|
||||||
|
def __check_expired(self):
|
||||||
|
"""Check if the geoip database has been opened longer than 8 hours,
|
||||||
|
and re-open it, as it will probably will have been re-downloaded"""
|
||||||
|
now = time()
|
||||||
|
diff = datetime.fromtimestamp(now) - datetime.fromtimestamp(self.__last_mtime)
|
||||||
|
diff_hours = diff.total_seconds() // 3600
|
||||||
|
if diff_hours >= 8:
|
||||||
|
LOGGER.info("GeoIP databased loaded too long, re-opening", diff=diff)
|
||||||
|
self.__open()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
"""Check if GeoIP is enabled"""
|
||||||
|
return bool(self.__reader)
|
||||||
|
|
||||||
|
def city(self, ip_address: str) -> Optional[City]:
|
||||||
|
"""Wrapper for Reader.city"""
|
||||||
|
if not self.enabled:
|
||||||
|
return None
|
||||||
|
self.__check_expired()
|
||||||
|
try:
|
||||||
|
return self.__reader.city(ip_address)
|
||||||
|
except (GeoIP2Error, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
||||||
|
"""Wrapper for self.city that returns a dict"""
|
||||||
|
city = self.city(ip_address)
|
||||||
|
if not city:
|
||||||
|
return None
|
||||||
|
city_dict: GeoIPDict = {
|
||||||
|
"continent": city.continent.code,
|
||||||
|
"country": city.country.iso_code,
|
||||||
|
"lat": city.location.latitude,
|
||||||
|
"long": city.location.longitude,
|
||||||
|
"city": "",
|
||||||
|
}
|
||||||
|
if city.city.name:
|
||||||
|
city_dict["city"] = city.city.name
|
||||||
|
return city_dict
|
||||||
|
|
||||||
|
|
||||||
|
GEOIP_READER = GeoIPReader()
|
||||||
|
|
|
@ -10,7 +10,6 @@ from django.db import models
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from geoip2.errors import GeoIP2Error
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
from requests import RequestException, post
|
from requests import RequestException, post
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
@ -160,20 +159,10 @@ class Event(ExpiringModel):
|
||||||
|
|
||||||
def with_geoip(self): # pragma: no cover
|
def with_geoip(self): # pragma: no cover
|
||||||
"""Apply GeoIP Data, when enabled"""
|
"""Apply GeoIP Data, when enabled"""
|
||||||
if not GEOIP_READER:
|
city = GEOIP_READER.city_dict(self.client_ip)
|
||||||
|
if not city:
|
||||||
return
|
return
|
||||||
try:
|
self.context["geo"] = city
|
||||||
response = GEOIP_READER.city(self.client_ip)
|
|
||||||
self.context["geo"] = {
|
|
||||||
"continent": response.continent.code,
|
|
||||||
"country": response.country.iso_code,
|
|
||||||
"lat": response.location.latitude,
|
|
||||||
"long": response.location.longitude,
|
|
||||||
}
|
|
||||||
if response.city.name:
|
|
||||||
self.context["geo"]["city"] = response.city.name
|
|
||||||
except (GeoIP2Error, ValueError) as exc:
|
|
||||||
LOGGER.warning("Failed to add geoIP Data to event", exc=exc)
|
|
||||||
|
|
||||||
def _set_prom_metrics(self):
|
def _set_prom_metrics(self):
|
||||||
GAUGE_EVENTS.labels(
|
GAUGE_EVENTS.labels(
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Test GeoIP Wrapper"""
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from authentik.events.geo import GeoIPReader
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeoIP(TestCase):
|
||||||
|
"""Test GeoIP Wrapper"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.reader = GeoIPReader()
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
"""Test simple city wrapper"""
|
||||||
|
# IPs from
|
||||||
|
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
|
||||||
|
self.assertEqual(
|
||||||
|
self.reader.city_dict("2.125.160.216"),
|
||||||
|
{
|
||||||
|
"city": "Boxford",
|
||||||
|
"continent": "EU",
|
||||||
|
"country": "GB",
|
||||||
|
"lat": 51.75,
|
||||||
|
"long": -1.25,
|
||||||
|
},
|
||||||
|
)
|
|
@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from geoip2.errors import GeoIP2Error
|
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.geo import GEOIP_READER
|
||||||
|
@ -39,16 +38,12 @@ class PolicyRequest:
|
||||||
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
||||||
"""Load data from HTTP request, including geoip when enabled"""
|
"""Load data from HTTP request, including geoip when enabled"""
|
||||||
self.http_request = request
|
self.http_request = request
|
||||||
if not GEOIP_READER:
|
if not GEOIP_READER.enabled:
|
||||||
return
|
return
|
||||||
try:
|
client_ip = get_client_ip(request)
|
||||||
client_ip = get_client_ip(request)
|
if not client_ip:
|
||||||
if not client_ip:
|
return
|
||||||
return
|
self.context["geoip"] = GEOIP_READER.city(client_ip)
|
||||||
response = GEOIP_READER.city(client_ip)
|
|
||||||
self.context["geoip"] = response
|
|
||||||
except (GeoIP2Error, ValueError) as exc:
|
|
||||||
LOGGER.warning("failed to get geoip data", exc=exc)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
text = f"<PolicyRequest user={self.user}"
|
text = f"<PolicyRequest user={self.user}"
|
||||||
|
|
|
@ -14,6 +14,7 @@ class PytestTestRunner: # pragma: no cover
|
||||||
settings.TEST = True
|
settings.TEST = True
|
||||||
settings.CELERY_TASK_ALWAYS_EAGER = True
|
settings.CELERY_TASK_ALWAYS_EAGER = True
|
||||||
CONFIG.y_set("authentik.avatars", "none")
|
CONFIG.y_set("authentik.avatars", "none")
|
||||||
|
CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb")
|
||||||
|
|
||||||
def run_tests(self, test_labels):
|
def run_tests(self, test_labels):
|
||||||
"""Run pytest and return the exitcode.
|
"""Run pytest and return the exitcode.
|
||||||
|
|
Reference in New Issue