Skip to content

Commit

Permalink
Refactor Google/Microsoft OpenIdConnectFilter. Extract common code to…
Browse files Browse the repository at this point in the history
… Abstract class
  • Loading branch information
hirokiterashima committed Apr 8, 2024
1 parent 53e6e5c commit a684c01
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 138 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package org.wise.portal.presentation.web.filters;

import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.service.authentication.UserDetailsService;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;

public abstract class AbstractOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

protected String clientId;
protected String issuer;
protected String jwkUrl;
protected OAuth2RestTemplate openIdRestTemplate;

@Autowired
protected UserDetailsService userDetailsService;

protected AbstractOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

protected OAuth2AccessToken getAccessToken() {
OAuth2AccessToken accessToken;
try {
accessToken = openIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
return accessToken;
}

protected void saveRequestParams(HttpServletRequest request) {
saveRequestParameter(request, "accessCode");
saveRequestParameter(request, "redirectUrl");
}

protected void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) openIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
openIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName, parameterValue);
request.setAttribute(parameterName, parameterFromState);
}

protected void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(issuer)
|| !claims.get("aud").equals(clientId)) {
throw new RuntimeException("Invalid claims");
}
}

protected RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(jwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
}

protected void invalidateAccessToken() {
openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

protected abstract void setClientId(String clientId);

protected abstract void setIssuer(String issuer);

protected abstract void setJwkUrl(String jwkUrl);

protected abstract void setOpenIdRestTemplate(OAuth2RestTemplate template);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,87 +24,52 @@
package org.wise.portal.presentation.web.filters;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.domain.authentication.MutableUserDetails;
import org.wise.portal.service.authentication.UserDetailsService;
import org.wise.portal.service.session.SessionService;

public class GoogleOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

@Value("${google.clientId:}")
private String googleClientId;

@Value("${google.issuer:}")
private String googleIssuer;

@Value("${google.jwkUrl:}")
private String googleJwkUrl;

@Autowired
@Qualifier("googleOpenIdRestTemplate")
private OAuth2RestTemplate googleOpenIdRestTemplate;

@Autowired
private UserDetailsService userDetailsService;
public class GoogleOpenIdConnectFilter extends AbstractOpenIdConnectFilter {

@Autowired
protected SessionService sessionService;

public GoogleOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

@Override
public Authentication attemptAuthentication(HttpServletRequest request,
HttpServletResponse response) throws AuthenticationException, IOException {
saveRequestParameter(request, "accessCode");
saveRequestParameter(request, "redirectUrl");
OAuth2AccessToken accessToken;
try {
accessToken = googleOpenIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
saveRequestParams(request);
OAuth2AccessToken accessToken = getAccessToken();
try {
final String idToken = accessToken.getAdditionalInformation().get("id_token").toString();
String kid = JwtHelper.headers(idToken).get("kid");
final Jwt tokenDecoded = JwtHelper.decodeAndVerify(idToken, verifier(kid));
final Map<String, String> authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(),
Map.class);
verifyClaims(authInfo);
String googleUserId = authInfo.get("sub");
final UserDetails user = userDetailsService.loadUserByGoogleUserId(googleUserId);
final UserDetails user = userDetailsService.loadUserByGoogleUserId(authInfo.get("sub"));
invalidateAccesToken();
if (user != null) {
return new UsernamePasswordAuthenticationToken(user, null, user.getAuthorities());
Expand All @@ -116,47 +81,8 @@ public Authentication attemptAuthentication(HttpServletRequest request,
}
}

private void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) googleOpenIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
googleOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName,
parameterValue);
request.setAttribute(parameterName, parameterFromState);
}

private void invalidateAccesToken() {
googleOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

public void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(googleIssuer)
|| !claims.get("aud").equals(googleClientId)) {
throw new RuntimeException("Invalid claims");
}
}

private RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(googleJwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
}

public void setRestTemplate(OAuth2RestTemplate restTemplate2) {
googleOpenIdRestTemplate = restTemplate2;
}

private static class NoopAuthenticationManager implements AuthenticationManager {

@Override
public Authentication authenticate(Authentication authentication)
throws AuthenticationException {
throw new UnsupportedOperationException(
"No authentication should be done with this AuthenticationManager");
}
openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

@Override
Expand All @@ -168,4 +94,24 @@ protected void successfulAuthentication(HttpServletRequest request, HttpServletR
super.successfulAuthentication(request, response, chain, authentication);
}

@Value("${google.clientId:}")
protected void setClientId(String clientId) {
this.clientId = clientId;
}

@Value("${google.issuer:}")
protected void setIssuer(String issuer) {
this.issuer = issuer;
}

@Value("${google.jwkUrl:}")
protected void setJwkUrl(String jwkUrl) {
this.jwkUrl = jwkUrl;
}

@Autowired
@Qualifier("googleOpenIdRestTemplate")
protected void setOpenIdRestTemplate(OAuth2RestTemplate template) {
this.openIdRestTemplate = template;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package org.wise.portal.presentation.web.filters;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.ServletException;
Expand All @@ -13,59 +10,29 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.presentation.web.exception.MicrosoftUserNotFoundException;
import org.wise.portal.service.authentication.UserDetailsService;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

public class MicrosoftOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

@Value("${microsoft.clientId:}")
private String microsoftClientId;

@Value("${microsoft.issuer:}")
private String microsoftIssuer;

@Value("${microsoft.jwkUrl:}")
private String microsoftJwkUrl;

@Autowired
@Qualifier("microsoftOpenIdRestTemplate")
private OAuth2RestTemplate microsoftOpenIdRestTemplate;

@Autowired
private UserDetailsService userDetailsService;
public class MicrosoftOpenIdConnectFilter extends AbstractOpenIdConnectFilter {

public MicrosoftOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

@Override
public Authentication attemptAuthentication(HttpServletRequest request,
HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
saveRequestParameter(request, "redirectUrl");
OAuth2AccessToken accessToken;
try {
accessToken = microsoftOpenIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
saveRequestParams(request);
OAuth2AccessToken accessToken = getAccessToken();
final String idToken = accessToken.getAdditionalInformation().get("id_token").toString();
String kid = JwtHelper.headers(idToken).get("kid");
Jwt tokenDecoded = null;
Expand All @@ -78,8 +45,7 @@ public Authentication attemptAuthentication(HttpServletRequest request,
final Map<String, String> authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(),
Map.class);
verifyClaims(authInfo);
String microsoftUserId = authInfo.get("sub");
final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(microsoftUserId);
final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(authInfo.get("sub"));
invalidateAccessToken();
if (user != null) {
if (request.getAttribute("redirectUrl").toString().contains("join")) {
Expand All @@ -93,32 +59,24 @@ public Authentication attemptAuthentication(HttpServletRequest request,
}
}

private void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) microsoftOpenIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
microsoftOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName,
parameterValue);
request.setAttribute(parameterName, parameterFromState);
@Value("${microsoft.clientId:}")
protected void setClientId(String clientId) {
this.clientId = clientId;
}

private void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(microsoftIssuer)
|| !claims.get("aud").equals(microsoftClientId)) {
throw new RuntimeException("Invalid claims");
}
@Value("${microsoft.issuer:}")
protected void setIssuer(String issuer) {
this.issuer = issuer;
}

private RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(microsoftJwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
@Value("${microsoft.jwkUrl:}")
protected void setJwkUrl(String jwkUrl) {
this.jwkUrl = jwkUrl;
}

private void invalidateAccessToken() {
microsoftOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
@Autowired
@Qualifier("microsoftOpenIdRestTemplate")
protected void setOpenIdRestTemplate(OAuth2RestTemplate template) {
this.openIdRestTemplate = template;
}
}

0 comments on commit a684c01

Please sign in to comment.