Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(Social Login): Extract common code to abstract class #267

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package org.wise.portal.presentation.web.filters;

import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.service.authentication.UserDetailsService;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;

public abstract class AbstractOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

protected String clientId;
protected String issuer;
protected String jwkUrl;
protected OAuth2RestTemplate openIdRestTemplate;

@Autowired
protected UserDetailsService userDetailsService;

protected AbstractOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

protected OAuth2AccessToken getAccessToken() {
OAuth2AccessToken accessToken;
try {
accessToken = openIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
return accessToken;
}

protected void saveRequestParams(HttpServletRequest request) {
saveRequestParameter(request, "accessCode");
saveRequestParameter(request, "redirectUrl");
}

protected void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) openIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
openIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName, parameterValue);
request.setAttribute(parameterName, parameterFromState);
}

protected void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(issuer)
|| !claims.get("aud").equals(clientId)) {
throw new RuntimeException("Invalid claims");
}
}

protected RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(jwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
}

protected void invalidateAccessToken() {
openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

protected abstract void setClientId(String clientId);

protected abstract void setIssuer(String issuer);

protected abstract void setJwkUrl(String jwkUrl);

protected abstract void setOpenIdRestTemplate(OAuth2RestTemplate template);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,87 +24,52 @@
package org.wise.portal.presentation.web.filters;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.domain.authentication.MutableUserDetails;
import org.wise.portal.service.authentication.UserDetailsService;
import org.wise.portal.service.session.SessionService;

public class GoogleOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

@Value("${google.clientId:}")
private String googleClientId;

@Value("${google.issuer:}")
private String googleIssuer;

@Value("${google.jwkUrl:}")
private String googleJwkUrl;

@Autowired
@Qualifier("googleOpenIdRestTemplate")
private OAuth2RestTemplate googleOpenIdRestTemplate;

@Autowired
private UserDetailsService userDetailsService;
public class GoogleOpenIdConnectFilter extends AbstractOpenIdConnectFilter {

@Autowired
protected SessionService sessionService;

public GoogleOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

@Override
public Authentication attemptAuthentication(HttpServletRequest request,
HttpServletResponse response) throws AuthenticationException, IOException {
saveRequestParameter(request, "accessCode");
saveRequestParameter(request, "redirectUrl");
OAuth2AccessToken accessToken;
try {
accessToken = googleOpenIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
saveRequestParams(request);
OAuth2AccessToken accessToken = getAccessToken();
try {
final String idToken = accessToken.getAdditionalInformation().get("id_token").toString();
String kid = JwtHelper.headers(idToken).get("kid");
final Jwt tokenDecoded = JwtHelper.decodeAndVerify(idToken, verifier(kid));
final Map<String, String> authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(),
Map.class);
verifyClaims(authInfo);
String googleUserId = authInfo.get("sub");
final UserDetails user = userDetailsService.loadUserByGoogleUserId(googleUserId);
final UserDetails user = userDetailsService.loadUserByGoogleUserId(authInfo.get("sub"));
invalidateAccesToken();
if (user != null) {
return new UsernamePasswordAuthenticationToken(user, null, user.getAuthorities());
Expand All @@ -116,47 +81,8 @@ public Authentication attemptAuthentication(HttpServletRequest request,
}
}

private void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) googleOpenIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
googleOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName,
parameterValue);
request.setAttribute(parameterName, parameterFromState);
}

private void invalidateAccesToken() {
googleOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

public void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(googleIssuer)
|| !claims.get("aud").equals(googleClientId)) {
throw new RuntimeException("Invalid claims");
}
}

private RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(googleJwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
}

public void setRestTemplate(OAuth2RestTemplate restTemplate2) {
googleOpenIdRestTemplate = restTemplate2;
}

private static class NoopAuthenticationManager implements AuthenticationManager {

@Override
public Authentication authenticate(Authentication authentication)
throws AuthenticationException {
throw new UnsupportedOperationException(
"No authentication should be done with this AuthenticationManager");
}
openIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
}

@Override
Expand All @@ -168,4 +94,24 @@ protected void successfulAuthentication(HttpServletRequest request, HttpServletR
super.successfulAuthentication(request, response, chain, authentication);
}

@Value("${google.clientId:}")
protected void setClientId(String clientId) {
this.clientId = clientId;
}

@Value("${google.issuer:}")
protected void setIssuer(String issuer) {
this.issuer = issuer;
}

@Value("${google.jwkUrl:}")
protected void setJwkUrl(String jwkUrl) {
this.jwkUrl = jwkUrl;
}

@Autowired
@Qualifier("googleOpenIdRestTemplate")
protected void setOpenIdRestTemplate(OAuth2RestTemplate template) {
this.openIdRestTemplate = template;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package org.wise.portal.presentation.web.filters;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Date;
import java.util.Map;

import javax.servlet.ServletException;
Expand All @@ -13,59 +10,29 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.RsaVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.wise.portal.presentation.web.exception.MicrosoftUserNotFoundException;
import org.wise.portal.service.authentication.UserDetailsService;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

public class MicrosoftOpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {

@Value("${microsoft.clientId:}")
private String microsoftClientId;

@Value("${microsoft.issuer:}")
private String microsoftIssuer;

@Value("${microsoft.jwkUrl:}")
private String microsoftJwkUrl;

@Autowired
@Qualifier("microsoftOpenIdRestTemplate")
private OAuth2RestTemplate microsoftOpenIdRestTemplate;

@Autowired
private UserDetailsService userDetailsService;
public class MicrosoftOpenIdConnectFilter extends AbstractOpenIdConnectFilter {

public MicrosoftOpenIdConnectFilter(String defaultFilterProcessesUrl) {
super(defaultFilterProcessesUrl);
setAuthenticationManager(new NoopAuthenticationManager());
}

@Override
public Authentication attemptAuthentication(HttpServletRequest request,
HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
saveRequestParameter(request, "redirectUrl");
OAuth2AccessToken accessToken;
try {
accessToken = microsoftOpenIdRestTemplate.getAccessToken();
} catch (final OAuth2Exception e) {
throw new BadCredentialsException("Could not obtain access token", e);
}
saveRequestParams(request);
OAuth2AccessToken accessToken = getAccessToken();
final String idToken = accessToken.getAdditionalInformation().get("id_token").toString();
String kid = JwtHelper.headers(idToken).get("kid");
Jwt tokenDecoded = null;
Expand All @@ -78,8 +45,7 @@ public Authentication attemptAuthentication(HttpServletRequest request,
final Map<String, String> authInfo = new ObjectMapper().readValue(tokenDecoded.getClaims(),
Map.class);
verifyClaims(authInfo);
String microsoftUserId = authInfo.get("sub");
final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(microsoftUserId);
final UserDetails user = userDetailsService.loadUserByMicrosoftUserId(authInfo.get("sub"));
invalidateAccessToken();
if (user != null) {
if (request.getAttribute("redirectUrl").toString().contains("join")) {
Expand All @@ -93,32 +59,24 @@ public Authentication attemptAuthentication(HttpServletRequest request,
}
}

private void saveRequestParameter(HttpServletRequest request, String parameterName) {
String parameterValue = request.getParameter(parameterName);
String parameterFromState = (String) microsoftOpenIdRestTemplate.getOAuth2ClientContext()
.removePreservedState(parameterName);
microsoftOpenIdRestTemplate.getOAuth2ClientContext().setPreservedState(parameterName,
parameterValue);
request.setAttribute(parameterName, parameterFromState);
@Value("${microsoft.clientId:}")
protected void setClientId(String clientId) {
this.clientId = clientId;
}

private void verifyClaims(Map claims) {
int exp = (int) claims.get("exp");
Date expireDate = new Date(exp * 1000L);
Date now = new Date();
if (expireDate.before(now) || !claims.get("iss").equals(microsoftIssuer)
|| !claims.get("aud").equals(microsoftClientId)) {
throw new RuntimeException("Invalid claims");
}
@Value("${microsoft.issuer:}")
protected void setIssuer(String issuer) {
this.issuer = issuer;
}

private RsaVerifier verifier(String kid) throws Exception {
JwkProvider provider = new UrlJwkProvider(new URL(microsoftJwkUrl));
Jwk jwk = provider.get(kid);
return new RsaVerifier((RSAPublicKey) jwk.getPublicKey());
@Value("${microsoft.jwkUrl:}")
protected void setJwkUrl(String jwkUrl) {
this.jwkUrl = jwkUrl;
}

private void invalidateAccessToken() {
microsoftOpenIdRestTemplate.getOAuth2ClientContext().setAccessToken((OAuth2AccessToken) null);
@Autowired
@Qualifier("microsoftOpenIdRestTemplate")
protected void setOpenIdRestTemplate(OAuth2RestTemplate template) {
this.openIdRestTemplate = template;
}
}
Loading