From d346c01944ef3656e4401d108651002e381296cf Mon Sep 17 00:00:00 2001 From: Hiroki Terashima Date: Tue, 9 Apr 2024 08:52:00 -0700 Subject: [PATCH] Refactor Google/Microsoft OpenIdConnectFilter. Extract common code to Abstract class (#267) --- .../filters/AbstractOpenIdConnectFilter.java | 88 +++++++++++++++ .../filters/GoogleOpenIdConnectFilter.java | 104 +++++------------- .../filters/MicrosoftOpenIdConnectFilter.java | 76 +++---------- 3 files changed, 130 insertions(+), 138 deletions(-) create mode 100644 src/main/java/org/wise/portal/presentation/web/filters/AbstractOpenIdConnectFilter.java diff --git a/src/main/java/org/wise/portal/presentation/web/filters/AbstractOpenIdConnectFilter.java b/src/main/java/org/wise/portal/presentation/web/filters/AbstractOpenIdConnectFilter.java new file mode 100644 index 000000000..a0e4e01b2 --- /dev/null +++ b/src/main/java/org/wise/portal/presentation/web/filters/AbstractOpenIdConnectFilter.java @@ -0,0 +1,88 @@ +package org.wise.portal.presentation.web.filters; + +import java.net.URL; +import java.security.interfaces.RSAPublicKey; +import java.util.Date; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.jwt.crypto.sign.RsaVerifier; +import org.springframework.security.oauth2.client.OAuth2RestTemplate; +import org.springframework.security.oauth2.common.OAuth2AccessToken; +import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; +import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; +import org.wise.portal.service.authentication.UserDetailsService; + +import com.auth0.jwk.Jwk; +import com.auth0.jwk.JwkProvider; +import com.auth0.jwk.UrlJwkProvider; + +public abstract class AbstractOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter { + + protected String clientId; + protected String issuer; + protected String jwkUrl; + protected OAuth2RestTemplate openIdRestTemplate; + + @Autowired + protected UserDetailsService userDetailsService; + + protected AbstractOpenIdConnectFilter(String defaultFilterProcessesUrl) { + super(defaultFilterProcessesUrl); + setAuthenticationManager(new NoopAuthenticationManager()); + } + + protected OAuth2AccessToken getAccessToken() { + OAuth2AccessToken accessToken; + try { + accessToken = openIdRestTemplate.getAccessToken(); + } catch (final OAuth2Exception e) { + throw new BadCredentialsException("Could not obtain access token", e); + } + return accessToken; + } + + protected void saveRequestParams(HttpServletRequest request) { + saveRequestParameter(request, "accessCode"); + saveRequestParameter(request, "redirectUrl"); + } + + protected void saveRequestParameter(HttpServletRequest request, String parameterName) { + String parameterValue = request.getParameter(parameterName); + String parameterFromState = (String) openIdRestTemplate.getOAuth2ClientContext() + .removePreservedState(parameterName); + openIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName, parameterValue); + request.setAttribute(parameterName, parameterFromState); + } + + protected void verifyClaims(Map claims) { + int exp = (int) claims.get("exp"); + Date expireDate = new Date(exp * 1000L); + Date now = new Date(); + if (expireDate.before(now) || !claims.get("iss").equals(issuer) + || !claims.get("aud").equals(clientId)) { + throw new RuntimeException("Invalid claims"); + } + } + + protected RsaVerifier verifier(String kid) throws Exception { + JwkProvider provider = new UrlJwkProvider(new URL(jwkUrl)); + Jwk jwk = provider.get(kid); + return new RsaVerifier((RSAPublicKey) jwk.getPublicKey()); + } + + protected void invalidateAccessToken() { + openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null); + } + + protected abstract void setClientId(String clientId); + + protected abstract void setIssuer(String issuer); + + protected abstract void setJwkUrl(String jwkUrl); + + protected abstract void setOpenIdRestTemplate(OAuth2RestTemplate template); +} diff --git a/src/main/java/org/wise/portal/presentation/web/filters/GoogleOpenIdConnectFilter.java b/src/main/java/org/wise/portal/presentation/web/filters/GoogleOpenIdConnectFilter.java index 72b5a0dab..96ceb8445 100644 --- a/src/main/java/org/wise/portal/presentation/web/filters/GoogleOpenIdConnectFilter.java +++ b/src/main/java/org/wise/portal/presentation/web/filters/GoogleOpenIdConnectFilter.java @@ -24,9 +24,6 @@ package org.wise.portal.presentation.web.filters; import java.io.IOException; -import java.net.URL; -import java.security.interfaces.RSAPublicKey; -import java.util.Date; import java.util.Map; import javax.servlet.FilterChain; @@ -34,15 +31,11 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import com.auth0.jwk.Jwk; -import com.auth0.jwk.JwkProvider; -import com.auth0.jwk.UrlJwkProvider; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; -import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; @@ -50,52 +43,25 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.jwt.Jwt; import org.springframework.security.jwt.JwtHelper; -import org.springframework.security.jwt.crypto.sign.RsaVerifier; import org.springframework.security.oauth2.client.OAuth2RestTemplate; import org.springframework.security.oauth2.common.OAuth2AccessToken; -import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; -import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.wise.portal.domain.authentication.MutableUserDetails; -import org.wise.portal.service.authentication.UserDetailsService; import org.wise.portal.service.session.SessionService; -public class GoogleOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter { - - @Value("${google.clientId:}") - private String googleClientId; - - @Value("${google.issuer:}") - private String googleIssuer; - - @Value("${google.jwkUrl:}") - private String googleJwkUrl; - - @Autowired - @Qualifier("googleOpenIdRestTemplate") - private OAuth2RestTemplate googleOpenIdRestTemplate; - - @Autowired - private UserDetailsService userDetailsService; +public class GoogleOpenIdConnectFilter extends AbstractOpenIdConnectFilter { @Autowired protected SessionService sessionService; public GoogleOpenIdConnectFilter(String defaultFilterProcessesUrl) { super(defaultFilterProcessesUrl); - setAuthenticationManager(new NoopAuthenticationManager()); } @Override public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException { - saveRequestParameter(request, "accessCode"); - saveRequestParameter(request, "redirectUrl"); - OAuth2AccessToken accessToken; - try { - accessToken = googleOpenIdRestTemplate.getAccessToken(); - } catch (final OAuth2Exception e) { - throw new BadCredentialsException("Could not obtain access token", e); - } + saveRequestParams(request); + OAuth2AccessToken accessToken = getAccessToken(); try { final String idToken = accessToken.getAdditionalInformation().get("id_token").toString(); String kid = JwtHelper.headers(idToken).get("kid"); @@ -103,8 +69,7 @@ public Authentication attemptAuthentication(HttpServletRequest request, final Map authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(), Map.class); verifyClaims(authInfo); - String googleUserId = authInfo.get("sub"); - final UserDetails user = userDetailsService.loadUserByGoogleUserId(googleUserId); + final UserDetails user = userDetailsService.loadUserByGoogleUserId(authInfo.get("sub")); invalidateAccesToken(); if (user != null) { return new UsernamePasswordAuthenticationToken(user, null, user.getAuthorities()); @@ -116,47 +81,8 @@ public Authentication attemptAuthentication(HttpServletRequest request, } } - private void saveRequestParameter(HttpServletRequest request, String parameterName) { - String parameterValue = request.getParameter(parameterName); - String parameterFromState = (String) googleOpenIdRestTemplate.getOAuth2ClientContext() - .removePreservedState(parameterName); - googleOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName, - parameterValue); - request.setAttribute(parameterName, parameterFromState); - } - private void invalidateAccesToken() { - googleOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null); - } - - public void verifyClaims(Map claims) { - int exp = (int) claims.get("exp"); - Date expireDate = new Date(exp * 1000L); - Date now = new Date(); - if (expireDate.before(now) || !claims.get("iss").equals(googleIssuer) - || !claims.get("aud").equals(googleClientId)) { - throw new RuntimeException("Invalid claims"); - } - } - - private RsaVerifier verifier(String kid) throws Exception { - JwkProvider provider = new UrlJwkProvider(new URL(googleJwkUrl)); - Jwk jwk = provider.get(kid); - return new RsaVerifier((RSAPublicKey) jwk.getPublicKey()); - } - - public void setRestTemplate(OAuth2RestTemplate restTemplate2) { - googleOpenIdRestTemplate = restTemplate2; - } - - private static class NoopAuthenticationManager implements AuthenticationManager { - - @Override - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - throw new UnsupportedOperationException( - "No authentication should be done with this AuthenticationManager"); - } + openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null); } @Override @@ -168,4 +94,24 @@ protected void successfulAuthentication(HttpServletRequest request, HttpServletR super.successfulAuthentication(request, response, chain, authentication); } + @Value("${google.clientId:}") + protected void setClientId(String clientId) { + this.clientId = clientId; + } + + @Value("${google.issuer:}") + protected void setIssuer(String issuer) { + this.issuer = issuer; + } + + @Value("${google.jwkUrl:}") + protected void setJwkUrl(String jwkUrl) { + this.jwkUrl = jwkUrl; + } + + @Autowired + @Qualifier("googleOpenIdRestTemplate") + protected void setOpenIdRestTemplate(OAuth2RestTemplate template) { + this.openIdRestTemplate = template; + } } diff --git a/src/main/java/org/wise/portal/presentation/web/filters/MicrosoftOpenIdConnectFilter.java b/src/main/java/org/wise/portal/presentation/web/filters/MicrosoftOpenIdConnectFilter.java index 61a198d8b..28f549dd6 100644 --- a/src/main/java/org/wise/portal/presentation/web/filters/MicrosoftOpenIdConnectFilter.java +++ b/src/main/java/org/wise/portal/presentation/web/filters/MicrosoftOpenIdConnectFilter.java @@ -1,9 +1,6 @@ package org.wise.portal.presentation.web.filters; import java.io.IOException; -import java.net.URL; -import java.security.interfaces.RSAPublicKey; -import java.util.Date; import java.util.Map; import javax.servlet.ServletException; @@ -13,59 +10,29 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; -import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.jwt.Jwt; import org.springframework.security.jwt.JwtHelper; -import org.springframework.security.jwt.crypto.sign.RsaVerifier; import org.springframework.security.oauth2.client.OAuth2RestTemplate; import org.springframework.security.oauth2.common.OAuth2AccessToken; -import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; -import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.wise.portal.presentation.web.exception.MicrosoftUserNotFoundException; -import org.wise.portal.service.authentication.UserDetailsService; -import com.auth0.jwk.Jwk; -import com.auth0.jwk.JwkProvider; -import com.auth0.jwk.UrlJwkProvider; import com.fasterxml.jackson.databind.ObjectMapper; -public class MicrosoftOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter { - - @Value("${microsoft.clientId:}") - private String microsoftClientId; - - @Value("${microsoft.issuer:}") - private String microsoftIssuer; - - @Value("${microsoft.jwkUrl:}") - private String microsoftJwkUrl; - - @Autowired - @Qualifier("microsoftOpenIdRestTemplate") - private OAuth2RestTemplate microsoftOpenIdRestTemplate; - - @Autowired - private UserDetailsService userDetailsService; +public class MicrosoftOpenIdConnectFilter extends AbstractOpenIdConnectFilter { public MicrosoftOpenIdConnectFilter(String defaultFilterProcessesUrl) { super(defaultFilterProcessesUrl); - setAuthenticationManager(new NoopAuthenticationManager()); } @Override public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException { - saveRequestParameter(request, "redirectUrl"); - OAuth2AccessToken accessToken; - try { - accessToken = microsoftOpenIdRestTemplate.getAccessToken(); - } catch (final OAuth2Exception e) { - throw new BadCredentialsException("Could not obtain access token", e); - } + saveRequestParams(request); + OAuth2AccessToken accessToken = getAccessToken(); final String idToken = accessToken.getAdditionalInformation().get("id_token").toString(); String kid = JwtHelper.headers(idToken).get("kid"); Jwt tokenDecoded = null; @@ -78,8 +45,7 @@ public Authentication attemptAuthentication(HttpServletRequest request, final Map authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(), Map.class); verifyClaims(authInfo); - String microsoftUserId = authInfo.get("sub"); - final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(microsoftUserId); + final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(authInfo.get("sub")); invalidateAccessToken(); if (user != null) { if (request.getAttribute("redirectUrl").toString().contains("join")) { @@ -93,32 +59,24 @@ public Authentication attemptAuthentication(HttpServletRequest request, } } - private void saveRequestParameter(HttpServletRequest request, String parameterName) { - String parameterValue = request.getParameter(parameterName); - String parameterFromState = (String) microsoftOpenIdRestTemplate.getOAuth2ClientContext() - .removePreservedState(parameterName); - microsoftOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName, - parameterValue); - request.setAttribute(parameterName, parameterFromState); + @Value("${microsoft.clientId:}") + protected void setClientId(String clientId) { + this.clientId = clientId; } - private void verifyClaims(Map claims) { - int exp = (int) claims.get("exp"); - Date expireDate = new Date(exp * 1000L); - Date now = new Date(); - if (expireDate.before(now) || !claims.get("iss").equals(microsoftIssuer) - || !claims.get("aud").equals(microsoftClientId)) { - throw new RuntimeException("Invalid claims"); - } + @Value("${microsoft.issuer:}") + protected void setIssuer(String issuer) { + this.issuer = issuer; } - private RsaVerifier verifier(String kid) throws Exception { - JwkProvider provider = new UrlJwkProvider(new URL(microsoftJwkUrl)); - Jwk jwk = provider.get(kid); - return new RsaVerifier((RSAPublicKey) jwk.getPublicKey()); + @Value("${microsoft.jwkUrl:}") + protected void setJwkUrl(String jwkUrl) { + this.jwkUrl = jwkUrl; } - private void invalidateAccessToken() { - microsoftOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null); + @Autowired + @Qualifier("microsoftOpenIdRestTemplate") + protected void setOpenIdRestTemplate(OAuth2RestTemplate template) { + this.openIdRestTemplate = template; } }