Skip to content

Commit

Permalink
Added support for JWT based authentication (#739)
Browse files Browse the repository at this point in the history
  • Loading branch information
sukalpomitra authored Dec 2, 2023
1 parent 45e00ac commit 8c5f0c7
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs-web-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
<groupId>org.slf4j</groupId>
<artifactId>jul-to-slf4j</artifactId>
</dependency>

<dependency>
<groupId>com.auth0</groupId>
<artifactId>java-jwt</artifactId>
<version>4.4.0</version>
</dependency>

<!-- Test dependencies -->
<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package com.sismics.util.filter;

import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.JWT;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.JWTVerifier;

import java.io.IOException;
import java.io.Reader;
import java.util.Base64;

import com.sismics.docs.core.constant.Constants;
import com.sismics.docs.core.dao.UserDao;
import com.sismics.docs.core.model.jpa.User;
import jakarta.json.Json;
import jakarta.json.JsonArray;
import jakarta.json.JsonObject;
import jakarta.json.JsonReader;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.http.HttpServletRequest;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.interfaces.RSAPublicKey;
import java.util.Objects;
import java.util.UUID;

import static java.util.Optional.ofNullable;

/**
* This filter is used to authenticate the user having an active session by validating a jwt token.
* The filter extracts the jwt token stored from Authorization header.
* It validates the token by calling an Identity Broker like KeyCloak.
* If validated, the user is retrieved, and the filter injects a UserPrincipal into the request attribute.
*
* @author smitra
*/
public class JwtBasedSecurityFilter extends SecurityFilter {
private static final Logger log = LoggerFactory.getLogger(JwtBasedSecurityFilter.class);
private static final okhttp3.OkHttpClient client = new okhttp3.OkHttpClient();
/**
* Name of the header used to store the authentication token.
*/
public static final String HEADER_NAME = "Authorization";
/**
* True if this authentication method is enabled.
*/
private boolean enabled;

@Override
public void init(FilterConfig filterConfig) {
enabled = Boolean.parseBoolean(filterConfig.getInitParameter("enabled"))
|| Boolean.parseBoolean(System.getProperty("docs.jwt_authentication"));
}

@Override
protected User authenticate(final HttpServletRequest request) {
if (!enabled) {
return null;
}
log.info("Jwt authentication started");
User user = null;
String token = extractAuthToken(request).replace("Bearer ", "");
DecodedJWT jwt = JWT.decode(token);
if (verifyJwt(jwt, token)) {
String email = jwt.getClaim("preferred_username").toString();
UserDao userDao = new UserDao();
user = userDao.getActiveByUsername(email);
if (user == null) {
user = new User();
user.setRoleId(Constants.DEFAULT_USER_ROLE);
user.setUsername(email);
user.setEmail(email);
user.setStorageQuota(Long.parseLong(ofNullable(System.getenv(Constants.GLOBAL_QUOTA_ENV))
.orElse("1073741824")));
user.setPassword(UUID.randomUUID().toString());
try {
userDao.create(user, email);
log.info("user created");
} catch (Exception e) {
log.info("Error:" + e.getMessage());
return null;
}
}
}
return user;
}

private boolean verifyJwt(final DecodedJWT jwt, final String token) {

try {
buildJWTVerifier(jwt).verify(token);
// if token is valid no exception will be thrown
log.info("Valid TOKEN");
return Boolean.TRUE;
} catch (CertificateException e) {
//if CertificateException comes from buildJWTVerifier()
log.info("InValid TOKEN: " + e.getMessage());
return Boolean.FALSE;
} catch (JWTVerificationException e) {
// if JWT Token in invalid
log.info("InValid TOKEN: " + e.getMessage() );
return Boolean.FALSE;
} catch (Exception e) {
// If any other exception comes
log.info("InValid TOKEN, Exception Occurred: " + e.getMessage());
return Boolean.FALSE;
}
}

private String extractAuthToken(final HttpServletRequest request) {
return ofNullable(request.getHeader("Authorization")).orElse("");
}

private RSAPublicKey getPublicKey(DecodedJWT jwt) {
String jwtIssuerCerts = jwt.getIssuer() + "/protocol/openid-connect/certs";
String publicKey = "";
RSAPublicKey rsaPublicKey = null;
Request request = new Request.Builder()
.url(jwtIssuerCerts)
.get()
.build();
try (Response response = client.newCall(request).execute()) {
log.info("Successfully called the jwt issuer at: " + jwtIssuerCerts + " - " + response.code());
assert response.body() != null;
if (response.isSuccessful()) {
try (Reader reader = response.body().charStream()) {
try (JsonReader jsonReader = Json.createReader(reader)) {
JsonObject jwks = jsonReader.readObject();
JsonArray keys = jwks.getJsonArray("keys");
publicKey = keys.stream().filter(key -> Objects.equals(key.asJsonObject().getString("kid"),
jwt.getKeyId()))
.findFirst()
.map(k -> k.asJsonObject().getJsonArray("x5c").getString(0))
.orElse("");
var decode = Base64.getDecoder().decode(publicKey);
var certificate = CertificateFactory.getInstance("X.509")
.generateCertificate(new ByteArrayInputStream(decode));
rsaPublicKey = (RSAPublicKey) certificate.getPublicKey();
}
}
}
} catch (IOException e) {
log.error("Error calling the jwt issuer at: " + jwtIssuerCerts, e);
} catch (CertificateException e) {
log.error("Error in getting the certificate: ", e);
}
return rsaPublicKey;
}

private JWTVerifier buildJWTVerifier(DecodedJWT jwt) throws CertificateException {
var algo = Algorithm.RSA256(getPublicKey(jwt), null);
return JWT.require(algo).build();
}
}
15 changes: 15 additions & 0 deletions docs-web/src/main/webapp/WEB-INF/web.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
<async-supported>true</async-supported>
</filter>

<filter>
<filter-name>jwtBasedSecurityFilter</filter-name>
<filter-class>com.sismics.util.filter.JwtBasedSecurityFilter</filter-class>
<async-supported>true</async-supported>
<init-param>
<param-name>enabled</param-name>
<param-value>false</param-value>
</init-param>
</filter>

<filter>
<filter-name>headerBasedSecurityFilter</filter-name>
<filter-class>com.sismics.util.filter.HeaderBasedSecurityFilter</filter-class>
Expand All @@ -59,6 +69,11 @@
<url-pattern>/api/*</url-pattern>
</filter-mapping>

<filter-mapping>
<filter-name>jwtBasedSecurityFilter</filter-name>
<url-pattern>/api/*</url-pattern>
</filter-mapping>

<filter-mapping>
<filter-name>headerBasedSecurityFilter</filter-name>
<url-pattern>/api/*</url-pattern>
Expand Down

0 comments on commit 8c5f0c7

Please sign in to comment.