Skip to content

Commit

Permalink
Use Nimbus OAuth SDK for OIDC
Browse files Browse the repository at this point in the history
  • Loading branch information
oneonestar authored and ebyhr committed May 20, 2024
1 parent b97a4a1 commit c052bb2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 226 deletions.
13 changes: 13 additions & 0 deletions gateway-ha/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@
<version>${dep.guice.version}</version>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<version>9.37.3</version>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
<version>11.10.1</version>
<classifier>jdk11</classifier>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,35 @@
import com.auth0.jwt.JWT;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import io.airlift.log.Logger;
import io.trino.gateway.ha.config.OAuthConfiguration;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.client.ClientBuilder;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.Form;
import jakarta.ws.rs.core.Response;

import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

import static java.lang.String.format;
import static com.nimbusds.oauth2.sdk.ResponseType.CODE;
import static jakarta.ws.rs.core.Response.Status.FOUND;
import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED;

public class LbOAuthManager
{
Expand Down Expand Up @@ -72,35 +79,30 @@ public String getUserIdField()
*/
public Response exchangeCodeForToken(String code, String redirectLocation)
{
String tokenEndpoint = oauthConfig.getTokenEndpoint().toString();
String clientId = oauthConfig.getClientId();
String clientSecret = oauthConfig.getClientSecret();
String redirectUri = oauthConfig.getRedirectUrl().toString();
Optional<URI> redirectWebUrl = oauthConfig.getRedirectWebUrl();
Client oauthClient = ClientBuilder.newBuilder().build();
TokenRequest tokenRequest = new TokenRequest(
oauthConfig.getTokenEndpoint(),
new ClientSecretBasic(new ClientID(oauthConfig.getClientId()), new Secret(oauthConfig.getClientSecret())),
new AuthorizationCodeGrant(new AuthorizationCode(code), oauthConfig.getRedirectUrl()));

Form form = new Form().param("grant_type", "authorization_code")
.param("client_id", clientId)
.param("client_secret", clientSecret)
.param("code", code)
.param("redirect_uri", redirectUri);

Response tokenResponse = oauthClient
.target(tokenEndpoint)
.request()
.post(Entity.form(form));

if (tokenResponse.getStatusInfo().getFamily() != Response.Status.Family.SUCCESSFUL) {
String message = format("token response failed with code %d - %s", tokenResponse.getStatus(), tokenResponse.readEntity(String.class));
log.error(message);
return Response.status(500).entity(message).build();
TokenResponse tokenResponse;
try {
tokenResponse = OIDCTokenResponseParser.parse(tokenRequest.toHTTPRequest().send());
}
catch (ParseException | IOException e) {
log.error("Failed to parse token response: %s", e.getMessage());
return Response.status(UNAUTHORIZED).build();
}

OidcTokens tokens = tokenResponse.readEntity(OidcTokens.class);
if (!tokenResponse.indicatesSuccess()) {
HTTPResponse httpResponse = tokenResponse.toErrorResponse().toHTTPResponse();
log.error("token response failed with code %d - %s", httpResponse.getStatusCode(), httpResponse.getBody());
return Response.status(UNAUTHORIZED).build();
}

return Response.status(302)
.location(redirectWebUrl.orElse(URI.create(redirectLocation)))
.cookie(SessionCookie.getTokenCookie(tokens.getIdToken()))
OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse();
return Response.status(FOUND)
.location(oauthConfig.getRedirectWebUrl().orElse(URI.create(redirectLocation)))
.cookie(SessionCookie.getTokenCookie(successResponse.getOIDCTokens().getIDToken().serialize()))
.build();
}

Expand All @@ -111,13 +113,14 @@ public Response exchangeCodeForToken(String code, String redirectLocation)
*/
public String getAuthorizationCode()
{
String authorizationEndpoint = oauthConfig.getAuthorizationEndpoint().toString();
String clientId = oauthConfig.getClientId();
String redirectUrl = oauthConfig.getRedirectUrl().toString();
String scopes = String.join("+", oauthConfig.getScopes());
return format(
"%s?client_id=%s&response_type=code&redirect_uri=%s&scope=%s",
authorizationEndpoint, clientId, redirectUrl, scopes);
AuthenticationRequest request = new AuthenticationRequest.Builder(
CODE,
new Scope(oauthConfig.getScopes().toArray(String[]::new)),
new ClientID(oauthConfig.getClientId()),
oauthConfig.getRedirectUrl())
.endpointURI(oauthConfig.getAuthorizationEndpoint())
.build();
return request.toURI().toString();
}

/**
Expand Down Expand Up @@ -161,135 +164,4 @@ public List<String> processPagePermissions(List<String> roles)
.flatMap(role -> Stream.of(pagePermissions.get(role).split("_")))
.distinct().toList();
}

@JsonIgnoreProperties(ignoreUnknown = true)
static final class OidcTokens
{
private final String accessToken;
private final String idToken;
private final String scope;
private final String refreshToken;
private final String tokenType;
private final String expiresIn;

@JsonCreator
public OidcTokens(@JsonProperty("id_token") String idToken,
@JsonProperty("access_token") String accessToken,
@JsonProperty("refresh_token") String refreshToken,
@JsonProperty("token_type") String tokenType,
@JsonProperty("expires_in") String expiresIn,
@JsonProperty("scope") String scope)
{
this.accessToken = accessToken;
this.idToken = idToken;
this.tokenType = tokenType;
this.expiresIn = expiresIn;
this.scope = scope;
this.refreshToken = refreshToken;
}

@JsonProperty
public String getAccessToken()
{
return this.accessToken;
}

@JsonProperty
public String getIdToken()
{
return this.idToken;
}

@JsonProperty
public String getScope()
{
return this.scope;
}

@JsonProperty
public String getRefreshToken()
{
return this.refreshToken;
}

@JsonProperty
public String getTokenType()
{
return this.tokenType;
}

@JsonProperty
public String getExpiresIn()
{
return this.expiresIn;
}

@Override
public boolean equals(final Object o)
{
if (o == this) {
return true;
}
if (!(o instanceof OidcTokens other)) {
return false;
}
final Object accessToken = this.getAccessToken();
final Object otherAccessToken = other.getAccessToken();
if (!Objects.equals(accessToken, otherAccessToken)) {
return false;
}
final Object idToken = this.getIdToken();
final Object otherIdToken = other.getIdToken();
if (!Objects.equals(idToken, otherIdToken)) {
return false;
}
final Object scope = this.getScope();
final Object otherScope = other.getScope();
if (!Objects.equals(scope, otherScope)) {
return false;
}
final Object refreshToken = this.getRefreshToken();
final Object otherRefreshToken = other.getRefreshToken();
if (!Objects.equals(refreshToken, otherRefreshToken)) {
return false;
}
final Object tokenType = this.getTokenType();
final Object otherTokenType = other.getTokenType();
if (!Objects.equals(tokenType, otherTokenType)) {
return false;
}
final Object expiresIn = this.getExpiresIn();
final Object otherExpiresIn = other.getExpiresIn();
return Objects.equals(expiresIn, otherExpiresIn);
}

@Override
public int hashCode()
{
final int prime = 59;
int result = 1;
final Object accessToken = this.getAccessToken();
result = result * prime + (accessToken == null ? 43 : accessToken.hashCode());
final Object idToken = this.getIdToken();
result = result * prime + (idToken == null ? 43 : idToken.hashCode());
final Object scope = this.getScope();
result = result * prime + (scope == null ? 43 : scope.hashCode());
final Object refreshToken = this.getRefreshToken();
result = result * prime + (refreshToken == null ? 43 : refreshToken.hashCode());
final Object tokenType = this.getTokenType();
result = result * prime + (tokenType == null ? 43 : tokenType.hashCode());
final Object expiresIn = this.getExpiresIn();
result = result * prime + (expiresIn == null ? 43 : expiresIn.hashCode());
return result;
}

@Override
public String toString()
{
return "LbOAuthManager.OidcTokens(accessToken=" + this.getAccessToken() +
", idToken=" + this.getIdToken() + ", scope=" + this.getScope() +
", refreshToken=" + this.getRefreshToken() + ", tokenType=" + this.getTokenType() +
", expiresIn=" + this.getExpiresIn() + ")";
}
}
}

This file was deleted.

0 comments on commit c052bb2

Please sign in to comment.