package radius import ( "context" "net" "sort" "sync" log "github.com/sirupsen/logrus" "goauthentik.io/internal/config" "goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/radius/metrics" "layeh.com/radius" ) type ProviderInstance struct { ClientNetworks []*net.IPNet SharedSecret []byte appSlug string flowSlug string s *RadiusServer log *log.Entry } type RadiusServer struct { s radius.PacketServer log *log.Entry ac *ak.APIController providers []*ProviderInstance } func NewServer(ac *ak.APIController) *RadiusServer { rs := &RadiusServer{ log: log.WithField("logger", "authentik.outpost.radius"), ac: ac, providers: []*ProviderInstance{}, } rs.s = radius.PacketServer{ Handler: rs, SecretSource: rs, Addr: config.Get().Listen.Radius, } return rs } type cpp struct { c *net.IPNet p *ProviderInstance } func (rs *RadiusServer) RADIUSSecret(ctx context.Context, remoteAddr net.Addr) ([]byte, error) { matchedPrefixes := []cpp{} host, _, err := net.SplitHostPort(remoteAddr.String()) if err != nil { rs.log.WithError(err).Warning("Failed to get remote IP") return nil, err } ip := net.ParseIP(host) // Check all networks we have and remember all matching prefixes for _, p := range rs.providers { for _, cidr := range p.ClientNetworks { if cidr.Contains(ip) { matchedPrefixes = append(matchedPrefixes, cpp{ c: cidr, p: p, }) } } } if len(matchedPrefixes) < 1 { rs.log.WithField("ip", ip.String()).Warning("No matching prefixes found") return []byte{}, nil } // Sort matched cidrs by prefix length sort.Slice(matchedPrefixes, func(i, j int) bool { _, bi := matchedPrefixes[i].c.Mask.Size() _, bj := matchedPrefixes[j].c.Mask.Size() return bi < bj }) candidate := matchedPrefixes[0] rs.log.WithField("ip", ip.String()).WithField("cidr", candidate.c.String()).Debug("Matched CIDR") return candidate.p.SharedSecret, nil } func (rs *RadiusServer) Start() error { wg := sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() metrics.RunServer() }() go func() { defer wg.Done() err := rs.StartRadiusServer() if err != nil { panic(err) } }() wg.Wait() return nil } func (rs *RadiusServer) Stop() error { ctx, cancel := context.WithCancel(context.Background()) err := rs.s.Shutdown(ctx) cancel() return err } func (rs *RadiusServer) TimerFlowCacheExpiry(context.Context) {} func (rs *RadiusServer) Type() string { return "radius" }