diff --git a/src/main/java/de/tum/cit/aet/artemis/core/config/websocket/WebsocketConfiguration.java b/src/main/java/de/tum/cit/aet/artemis/core/config/websocket/WebsocketConfiguration.java index d0c6941cc698..4a5dfdacfa24 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/config/websocket/WebsocketConfiguration.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/config/websocket/WebsocketConfiguration.java @@ -19,7 +19,6 @@ import java.util.regex.Pattern; import jakarta.annotation.Nullable; -import jakarta.servlet.http.Cookie; import jakarta.validation.constraints.NotNull; import org.slf4j.Logger; @@ -28,6 +27,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Profile; +import org.springframework.http.HttpStatusCode; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -52,7 +52,6 @@ import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import org.springframework.web.util.WebUtils; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Iterators; @@ -201,9 +200,14 @@ public HandshakeInterceptor httpSessionHandshakeInterceptor() { public boolean beforeHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler, @NotNull Map attributes) { if (request instanceof ServletServerHttpRequest servletRequest) { - attributes.put(IP_ADDRESS, servletRequest.getRemoteAddress()); - Cookie jwtCookie = WebUtils.getCookie(servletRequest.getServletRequest(), JWTFilter.JWT_COOKIE_NAME); - return JWTFilter.isJwtCookieValid(tokenProvider, jwtCookie); + try { + attributes.put(IP_ADDRESS, servletRequest.getRemoteAddress()); + return JWTFilter.extractValidJwt(servletRequest.getServletRequest(), tokenProvider) != null; + } + catch (IllegalArgumentException e) { + response.setStatusCode(HttpStatusCode.valueOf(400)); + return false; + } } return false; } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilter.java b/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilter.java index a7373fcd9874..ff1ddcaaf3e3 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilter.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilter.java @@ -2,12 +2,14 @@ import java.io.IOException; +import jakarta.annotation.Nullable; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -22,6 +24,10 @@ public class JWTFilter extends GenericFilterBean { public static final String JWT_COOKIE_NAME = "jwt"; + private static final String AUTHORIZATION_HEADER = "Authorization"; + + private static final String BEARER_PREFIX = "Bearer "; + private final TokenProvider tokenProvider; public JWTFilter(TokenProvider tokenProvider) { @@ -31,26 +37,89 @@ public JWTFilter(TokenProvider tokenProvider) { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest; - Cookie jwtCookie = WebUtils.getCookie(httpServletRequest, JWT_COOKIE_NAME); - if (isJwtCookieValid(this.tokenProvider, jwtCookie)) { - Authentication authentication = this.tokenProvider.getAuthentication(jwtCookie.getValue()); + HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse; + String jwtToken; + try { + jwtToken = extractValidJwt(httpServletRequest, this.tokenProvider); + } + catch (IllegalArgumentException e) { + httpServletResponse.sendError(HttpServletResponse.SC_BAD_REQUEST); + return; + } + + if (jwtToken != null) { + Authentication authentication = this.tokenProvider.getAuthentication(jwtToken); SecurityContextHolder.getContext().setAuthentication(authentication); } + filterChain.doFilter(servletRequest, servletResponse); } /** - * Checks if the cookie containing the jwt is valid + * Extracts the valid jwt found in the cookie or the Authorization header * - * @param tokenProvider the artemis token provider used to generate and validate jwt's - * @param jwtCookie the cookie containing the jwt - * @return true if the jwt is valid, false if missing or invalid + * @param httpServletRequest the http request + * @param tokenProvider the Artemis token provider used to generate and validate jwt's + * @return the valid jwt or null if not found or invalid */ - public static boolean isJwtCookieValid(TokenProvider tokenProvider, Cookie jwtCookie) { + public static @Nullable String extractValidJwt(HttpServletRequest httpServletRequest, TokenProvider tokenProvider) { + var cookie = WebUtils.getCookie(httpServletRequest, JWT_COOKIE_NAME); + var authHeader = httpServletRequest.getHeader(AUTHORIZATION_HEADER); + + if (cookie == null && authHeader == null) { + return null; + } + + if (cookie != null && authHeader != null) { + // Single Method Enforcement: Only one method of authentication is allowed + throw new IllegalArgumentException("Multiple authentication methods detected: Both JWT cookie and Bearer token are present"); + } + + String jwtToken = cookie != null ? getJwtFromCookie(cookie) : getJwtFromBearer(authHeader); + + if (!isJwtValid(tokenProvider, jwtToken)) { + return null; + } + + return jwtToken; + } + + /** + * Extracts the jwt from the cookie + * + * @param jwtCookie the cookie with Key "jwt" + * @return the jwt or null if not found + */ + private static @Nullable String getJwtFromCookie(@Nullable Cookie jwtCookie) { if (jwtCookie == null) { - return false; + return null; + } + return jwtCookie.getValue(); + } + + /** + * Extracts the jwt from the Authorization header + * + * @param jwtBearer the content of the Authorization header + * @return the jwt or null if not found + */ + private static @Nullable String getJwtFromBearer(@Nullable String jwtBearer) { + if (!StringUtils.hasText(jwtBearer) || !jwtBearer.startsWith(BEARER_PREFIX)) { + return null; } - String jwt = jwtCookie.getValue(); - return StringUtils.hasText(jwt) && tokenProvider.validateTokenForAuthority(jwt); + + String token = jwtBearer.substring(BEARER_PREFIX.length()).trim(); + return StringUtils.hasText(token) ? token : null; + } + + /** + * Checks if the jwt is valid + * + * @param tokenProvider the Artemis token provider used to generate and validate jwt's + * @param jwtToken the jwt + * @return true if the jwt is valid, false if missing or invalid + */ + private static boolean isJwtValid(TokenProvider tokenProvider, @Nullable String jwtToken) { + return StringUtils.hasText(jwtToken) && tokenProvider.validateTokenForAuthority(jwtToken); } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/TokenProvider.java b/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/TokenProvider.java index 262ece79700d..044d897d12c7 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/TokenProvider.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/security/jwt/TokenProvider.java @@ -170,6 +170,11 @@ private Claims parseClaims(String authToken) { return Jwts.parser().verifyWith(key).build().parseSignedClaims(authToken).getPayload(); } + public T getClaim(String token, String claimName, Class claimType) { + Claims claims = parseClaims(token); + return claims.get(claimName, claimType); + } + public Date getExpirationDate(String authToken) { return parseClaims(authToken).getExpiration(); } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/web/open/PublicUserJwtResource.java b/src/main/java/de/tum/cit/aet/artemis/core/web/open/PublicUserJwtResource.java index 44e44a0ff87a..90020572f571 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/web/open/PublicUserJwtResource.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/web/open/PublicUserJwtResource.java @@ -2,6 +2,7 @@ import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; +import java.util.Map; import java.util.Optional; import jakarta.servlet.ServletException; @@ -69,7 +70,7 @@ public PublicUserJwtResource(JWTCookieService jwtCookieService, AuthenticationMa */ @PostMapping("authenticate") @EnforceNothing - public ResponseEntity authorize(@Valid @RequestBody LoginVM loginVM, @RequestHeader("User-Agent") String userAgent, HttpServletResponse response) { + public ResponseEntity> authorize(@Valid @RequestBody LoginVM loginVM, @RequestHeader("User-Agent") String userAgent, HttpServletResponse response) { var username = loginVM.getUsername(); var password = loginVM.getPassword(); @@ -86,7 +87,7 @@ public ResponseEntity authorize(@Valid @RequestBody LoginVM loginVM, @Requ ResponseCookie responseCookie = jwtCookieService.buildLoginCookie(rememberMe); response.addHeader(HttpHeaders.SET_COOKIE, responseCookie.toString()); - return ResponseEntity.ok().build(); + return ResponseEntity.ok(Map.of("access_token", responseCookie.getValue())); } catch (BadCredentialsException ex) { log.warn("Wrong credentials during login for user {}", loginVM.getUsername()); diff --git a/src/test/java/de/tum/cit/aet/artemis/core/authentication/InternalAuthenticationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/core/authentication/InternalAuthenticationIntegrationTest.java index 3a756bd65ede..99e424d35022 100644 --- a/src/test/java/de/tum/cit/aet/artemis/core/authentication/InternalAuthenticationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/core/authentication/InternalAuthenticationIntegrationTest.java @@ -10,6 +10,7 @@ import java.time.ZonedDateTime; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -26,6 +27,9 @@ import org.springframework.security.test.context.support.WithAnonymousUser; import org.springframework.security.test.context.support.WithMockUser; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + import de.tum.cit.aet.artemis.core.connector.GitlabRequestMockProvider; import de.tum.cit.aet.artemis.core.domain.Authority; import de.tum.cit.aet.artemis.core.domain.Course; @@ -35,6 +39,7 @@ import de.tum.cit.aet.artemis.core.repository.AuthorityRepository; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.security.SecurityUtils; +import de.tum.cit.aet.artemis.core.security.jwt.TokenProvider; import de.tum.cit.aet.artemis.core.service.user.PasswordService; import de.tum.cit.aet.artemis.core.util.CourseFactory; import de.tum.cit.aet.artemis.programming.test_repository.ProgrammingExerciseTestRepository; @@ -50,6 +55,9 @@ class InternalAuthenticationIntegrationTest extends AbstractSpringIntegrationJen @Autowired private PasswordService passwordService; + @Autowired + private TokenProvider tokenProvider; + @Autowired private ProgrammingExerciseTestRepository programmingExerciseRepository; @@ -223,6 +231,10 @@ void testJWTAuthentication() throws Exception { MockHttpServletResponse response = request.postWithoutResponseBody("/api/public/authenticate", loginVM, HttpStatus.OK, httpHeaders); AuthenticationIntegrationTestHelper.authenticationCookieAssertions(response.getCookie("jwt"), false); + + var responseBody = new ObjectMapper().readValue(response.getContentAsString(), new TypeReference>() { + }); + assertThat(tokenProvider.validateTokenForAuthority(responseBody.get("access_token").toString())).isTrue(); } @Test diff --git a/src/test/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilterTest.java b/src/test/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilterTest.java index 59392188c127..b50a63de8202 100644 --- a/src/test/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilterTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/core/security/jwt/JWTFilterTest.java @@ -47,7 +47,7 @@ void setup() { } @Test - void testJWTFilter() throws Exception { + void testJWTFilterCookie() throws Exception { UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password", Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority()))); String jwt = tokenProvider.createToken(authentication, false); @@ -61,6 +61,40 @@ void testJWTFilter() throws Exception { assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("test-user"); } + @Test + void testJWTFilterBearer() throws Exception { + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password", + Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority()))); + + String jwt = tokenProvider.createToken(authentication, false); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setCookies(new Cookie(JWTFilter.JWT_COOKIE_NAME, jwt)); + request.addHeader("Authorization", "Bearer " + jwt); + request.setRequestURI("/api/test"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + jwtFilter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + } + + @Test + void testJWTFilterCookieAndBearer() throws Exception { + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password", + Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority()))); + + String jwt = tokenProvider.createToken(authentication, false); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer " + jwt); + request.setRequestURI("/api/test"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + jwtFilter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("test-user"); + } + @Test void testJWTFilterInvalidToken() throws Exception { String jwt = "wrong_jwt";