diff --git a/musician/views.py b/musician/views.py index ce5e52c..7f7e7c3 100644 --- a/musician/views.py +++ b/musician/views.py @@ -2,6 +2,7 @@ from django.http import HttpResponseRedirect from django.shortcuts import render from django.urls import reverse_lazy +from django.utils.http import is_safe_url from django.views.generic.base import RedirectView, TemplateView from django.views.generic.edit import FormView @@ -48,6 +49,7 @@ class LoginView(FormView): template_name = 'auth/login.html' form_class = LoginForm success_url = reverse_lazy('musician:dashboard') + redirect_field_name = 'next' extra_context = {'version': get_version()} def get_form_kwargs(self): @@ -60,6 +62,31 @@ class LoginView(FormView): auth_login(self.request, form.username, form.token) return HttpResponseRedirect(self.get_success_url()) + def get_success_url(self): + url = self.get_redirect_url() + return url or self.success_url + + def get_redirect_url(self): + """Return the user-originating redirect URL if it's safe.""" + redirect_to = self.request.POST.get( + self.redirect_field_name, + self.request.GET.get(self.redirect_field_name, '') + ) + url_is_safe = is_safe_url( + url=redirect_to, + allowed_hosts={self.request.get_host()}, + require_https=self.request.is_secure(), + ) + return redirect_to if url_is_safe else '' + + def get_context_data(self, **kwargs): + context = super().get_context_data(**kwargs) + context.update({ + self.redirect_field_name: self.get_redirect_url(), + **(self.extra_context or {}) + }) + return context + class LogoutView(RedirectView): """