This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.
authentik/internal/web/brand_tls/brand_tls.go
Marc 'risson' Schmitt 77d8877efe
tenants -> brands, init new tenant model, migrate some config to tenants
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2023-11-21 18:23:58 +01:00

87 lines
1.8 KiB
Go

package brand_tls
import (
	"crypto/tls"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
	"goauthentik.io/api/v3"
	"goauthentik.io/internal/crypto"
	"goauthentik.io/internal/outpost/ak"
)
// Watcher periodically fetches brands from the authentik API, loads their web
// certificate keypairs into a crypto store, and picks a TLS certificate for an
// incoming connection based on the brands' domains.
type Watcher struct {
	client   *api.APIClient
	log      *log.Entry
	cs       *ak.CryptoStore  // keypairs loaded by Check and looked up by GetCertificate
	fallback *tls.Certificate // served when no brand certificate applies; nil if generation failed

	// mu guards brands, which Check replaces while GetCertificate may be
	// reading it concurrently from TLS handshakes.
	mu     sync.RWMutex
	brands []api.Brand
}

// NewWatcher creates a Watcher that uses client to fetch brands and their
// certificate keypairs. A self-signed certificate is generated up front and
// served when no brand certificate applies. If generating it fails, the error
// is logged and no fallback is set, so GetCertificate returns a nil
// certificate in that case (when used as tls.Config.GetCertificate, crypto/tls
// then falls back to Config.Certificates).
func NewWatcher(client *api.APIClient) *Watcher {
	l := log.WithField("logger", "authentik.router.brand_tls")
	w := &Watcher{
		client: client,
		log:    l,
		cs:     ak.NewCryptoStore(client.CryptoApi),
	}
	cert, err := crypto.GenerateSelfSignedCert()
	if err != nil {
		// Leave fallback nil rather than pointing it at a zero-value
		// certificate, which has no private key and would fail the handshake.
		l.WithError(err).Error("failed to generate default cert")
		return w
	}
	w.fallback = &cert
	return w
}

// Start runs Check immediately and then every three minutes. It never
// returns, so it is presumably meant to be run in its own goroutine.
func (w *Watcher) Start() {
	ticker := time.NewTicker(time.Minute * 3)
	w.log.Info("Starting Brand TLS Checker")
	for {
		w.Check()
		<-ticker.C
	}
}

// Check fetches the brands from the authentik API, loads every referenced web
// certificate keypair into the crypto store, and then publishes the new brand
// list for GetCertificate. If the API call fails, the previous brand list is
// kept. A keypair that fails to load is logged and skipped.
func (w *Watcher) Check() {
	w.log.Info("updating brand certificates")
	// NOTE(review): the request carries no paging parameters, so only the
	// first page of results is used; confirm the page size covers all brands.
	brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{})
	if err != nil {
		w.log.WithError(err).Warning("failed to get brands")
		return
	}
	for _, b := range brands.Results {
		kp := b.WebCertificate.Get()
		if kp == nil {
			continue
		}
		if err := w.cs.AddKeypair(*kp); err != nil {
			w.log.WithError(err).Warning("failed to add certificate")
		}
	}
	// Publish under the lock: GetCertificate reads brands concurrently.
	w.mu.Lock()
	w.brands = brands.Results
	w.mu.Unlock()
}

// GetCertificate has the signature of tls.Config.GetCertificate. Only brands
// with a web certificate are considered. Among those, it picks the brand whose
// domain is the longest suffix of the client's SNI server name. If no domain
// matches, it picks the default brand. If no brand applies, or the chosen
// brand's keypair is not in the crypto store, the fallback certificate is
// returned (which may be nil, see NewWatcher). The error is always nil.
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
	w.mu.RLock()
	defer w.mu.RUnlock()

	var (
		bestMatch    *api.Brand // brand whose domain is the longest matching suffix
		defaultBrand *api.Brand // default brand, used only when no domain matches
	)
	// Index into the slice instead of taking &t of a range variable: before
	// Go 1.22 the range variable is reused across iterations, so its address
	// would end up pointing at the last element.
	for i := range w.brands {
		b := &w.brands[i]
		if b.WebCertificate.Get() == nil {
			continue
		}
		// Default is an optional field; a nil pointer means "not default".
		if b.Default != nil && *b.Default {
			defaultBrand = b
		}
		if strings.HasSuffix(ch.ServerName, b.Domain) &&
			(bestMatch == nil || len(b.Domain) >= len(bestMatch.Domain)) {
			bestMatch = b
		}
	}

	selected := bestMatch
	if selected == nil {
		selected = defaultBrand
	}
	if selected == nil {
		return w.fallback, nil
	}
	if cert := w.cs.Get(*selected.WebCertificate.Get()); cert != nil {
		return cert, nil
	}
	return w.fallback, nil
}