diff --git a/server/src/main/java/access/api/DefaultErrorController.java b/server/src/main/java/access/api/DefaultErrorController.java index b51424c2..bbefcdbc 100644 --- a/server/src/main/java/access/api/DefaultErrorController.java +++ b/server/src/main/java/access/api/DefaultErrorController.java @@ -39,7 +39,7 @@ public DefaultErrorController(ErrorAttributes errorAttributes) { } @RequestMapping("/error") - public ResponseEntity error(HttpServletRequest request) throws URISyntaxException { + public ResponseEntity error(HttpServletRequest request) { WebRequest webRequest = new ServletWebRequest(request); Map result = this.errorAttributes.getErrorAttributes( webRequest, diff --git a/server/src/main/java/access/security/AuthorizationRequestCustomizer.java b/server/src/main/java/access/security/AuthorizationRequestCustomizer.java new file mode 100644 index 00000000..9841b4ed --- /dev/null +++ b/server/src/main/java/access/security/AuthorizationRequestCustomizer.java @@ -0,0 +1,50 @@ +package access.security; + +import access.model.Invitation; +import access.repository.InvitationRepository; +import jakarta.servlet.http.HttpSession; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import java.util.Optional; +import java.util.function.Consumer; + +public class AuthorizationRequestCustomizer implements Consumer { + + private final InvitationRepository invitationRepository; + private final String eduidEntityId; + + public AuthorizationRequestCustomizer(InvitationRepository invitationRepository, String eduidEntityId) { + this.invitationRepository = invitationRepository; + this.eduidEntityId = eduidEntityId; + } + + @Override + public void accept(OAuth2AuthorizationRequest.Builder builder) { + builder.additionalParameters(params -> { + RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes(); + HttpSession session = ((ServletRequestAttributes) requestAttributes) + .getRequest().getSession(false); + if (session == null) { + return; + } + DefaultSavedRequest savedRequest = (DefaultSavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST"); + String[] force = savedRequest.getParameterValues("force"); + if (force != null && force.length == 1) { + params.put("prompt", "login"); + } + String[] hash = savedRequest.getParameterValues("hash"); + if (hash != null && hash.length == 1) { + Optional optionalInvitation = invitationRepository.findByHash(hash[0]); + optionalInvitation.ifPresent(invitation -> { + if (invitation.isEduIDOnly()) { + params.put("login_hint", eduidEntityId); + } + }); + } + }); + } +} diff --git a/server/src/main/java/access/security/SecurityConfig.java b/server/src/main/java/access/security/SecurityConfig.java index b4ded311..400aa25e 100644 --- a/server/src/main/java/access/security/SecurityConfig.java +++ b/server/src/main/java/access/security/SecurityConfig.java @@ -158,37 +158,10 @@ private OAuth2AuthorizationRequestResolver authorizationRequestResolver( new DefaultOAuth2AuthorizationRequestResolver( clientRegistrationRepository, "/oauth2/authorization"); authorizationRequestResolver.setAuthorizationRequestCustomizer( - authorizationRequestCustomizer()); - + new AuthorizationRequestCustomizer(invitationRepository, eduidEntityId)); return authorizationRequestResolver; } - private Consumer authorizationRequestCustomizer() { - return customizer -> customizer - .additionalParameters(params -> { - RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes(); - HttpSession session = ((ServletRequestAttributes) requestAttributes) - .getRequest().getSession(false); - if (session == null) { - return; - } - DefaultSavedRequest savedRequest = (DefaultSavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST"); - String[] force = savedRequest.getParameterValues("force"); - if (force != null && force.length == 1) { - params.put("prompt", "login"); - } - String[] hash = savedRequest.getParameterValues("hash"); - if (hash != null && hash.length == 1) { - Optional optionalInvitation = invitationRepository.findByHash(hash[0]); - optionalInvitation.ifPresent(invitation -> { - if (invitation.isEduIDOnly()) { - params.put("login_hint", eduidEntityId); - } - }); - } - }); - } - @Bean @Order(2) SecurityFilterChain basicAuthenticationSecurityFilterChain(HttpSecurity http) throws Exception {