From 79214625ab13ed864f3e5d1ddea1f4410c1d8a98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Ka=C5=88ka?= Date: Sat, 30 Nov 2024 12:53:17 +0100 Subject: [PATCH 1/3] [Bug #303] Silence JWT exception from websocket connection and match log level with HTTP authentication exception logging --- .../handler/StompExceptionHandler.java | 44 +----------- .../handler/WebSocketExceptionHandler.java | 69 ++++++++++++++++++- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java index 4aeeb1883..5188b1811 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java @@ -8,9 +8,6 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; - /** * calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error */ @@ -29,14 +26,7 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler @Nonnull byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) { final Message message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build(); - boolean handled = false; - try { - handled = delegate(message, cause); - } catch (InvocationTargetException e) { - LOG.error("Exception thrown during exception handler invocation", e); - } catch (IllegalAccessException unexpected) { - // is checked by delegate - } + final boolean handled = webSocketExceptionHandler.delegate(message, cause); if (!handled) { LOG.error("STOMP sub-protocol exception", cause); @@ -44,36 +34,4 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); } - - /** - * Tries to match method on {@link #webSocketExceptionHandler} - * - * @return true when a method was found and called, false otherwise - * @throws IllegalArgumentException never - */ - private boolean delegate(Message message, Throwable throwable) - throws InvocationTargetException, IllegalAccessException { - if (throwable instanceof Exception exception) { - Method[] methods = webSocketExceptionHandler.getClass().getMethods(); - for (final Method method : methods) { - if (!method.canAccess(webSocketExceptionHandler)) { - continue; - } - Class[] params = method.getParameterTypes(); - if (params.length != 2) { - continue; - } - if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { - // message, exception - method.invoke(webSocketExceptionHandler, message, exception); - return true; - } else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) { - // exception, message - method.invoke(webSocketExceptionHandler, exception, message); - return true; - } - } - } - return false; - } } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java index c6042bb9a..b4239a095 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -11,6 +11,7 @@ import cz.cvut.kbss.termit.exception.InvalidParameterException; import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException; import cz.cvut.kbss.termit.exception.InvalidTermStateException; +import cz.cvut.kbss.termit.exception.JwtException; import cz.cvut.kbss.termit.exception.NotFoundException; import cz.cvut.kbss.termit.exception.PersistenceException; import cz.cvut.kbss.termit.exception.ResourceExistsException; @@ -25,6 +26,7 @@ import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException; import cz.cvut.kbss.termit.exception.importing.VocabularyImportException; import cz.cvut.kbss.termit.rest.handler.ErrorInfo; +import cz.cvut.kbss.termit.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; @@ -33,12 +35,15 @@ import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.context.request.async.AsyncRequestNotUsableException; import org.springframework.web.multipart.MaxUploadSizeExceededException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.net.URISyntaxException; import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause; @@ -92,8 +97,61 @@ private static ErrorInfo errorInfo(Message message, TermItException e) { e.getParameters()); } + /** + * Tries to match method on this object that matches signature with params + * + * @return true when a method was found and called, false otherwise + * @throws IllegalArgumentException never + */ + public boolean delegate(Message message, Throwable throwable) { + try { + return delegateInternal(message, throwable.getCause()); + } catch (InvocationTargetException invEx) { + LOG.error("Exception thrown during exception handler invocation", invEx); + } catch (IllegalAccessException unexpected) { + // is checked by delegate + } + return false; + } + + /** + * Tries to match method on this object that matches signature with params + * + * @return true when a method was found and called, false otherwise + * @throws IllegalArgumentException never + */ + private boolean delegateInternal(Message message, Throwable throwable) + throws InvocationTargetException, IllegalAccessException { + if (throwable instanceof Exception exception) { + Method[] methods = this.getClass().getMethods(); + for (final Method method : methods) { + if (!method.canAccess(this) || method.getName().equals("delegate") || method.getName().equals("delegateInternal")) { + continue; + } + Class[] params = method.getParameterTypes(); + if (params.length != 2) { + continue; + } + if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { + // message, exception + method.invoke(this, message, exception); + return true; + } else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) { + // exception, message + method.invoke(this, exception, message); + return true; + } + } + } + return false; + } + @MessageExceptionHandler public void messageDeliveryException(Message message, MessageDeliveryException e) { + if (!(e.getCause() instanceof MessageDeliveryException) && delegate(message, e.getCause())) { + return; + } + // messages without destination will be logged only on trace (hasDestination(message) ? LOG.atError() : LOG.atTrace()) .setMessage("Failed to send message with destination {}: {}") @@ -144,11 +202,16 @@ public ErrorInfo authorizationException(Message message, AuthorizationExcepti return errorInfo(message, e); } - @MessageExceptionHandler(AuthenticationException.class) + @MessageExceptionHandler({AuthenticationException.class, AuthenticationServiceException.class}) public ErrorInfo authenticationException(Message message, AuthenticationException e) { - LOG.atDebug().setCause(e).log(e.getMessage()); - LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}") + LOG.atWarn().setMessage("Authentication failure during message processing: {}\nMessage: {}") .addArgument(e.getMessage()).addArgument(message::toString).log(); + + if (ExceptionUtils.findCause(e, JwtException.class).isPresent()) { + return errorInfo(message, e); + } + + LOG.atDebug().setCause(e).log(e.getMessage()); return errorInfo(message, e); } From 7b7a2c266e53d8f600d1de906d3045257664404b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Ka=C5=88ka?= Date: Sat, 30 Nov 2024 13:22:44 +0100 Subject: [PATCH 2/3] [Bug #303] Update websocket security tests to reflect new exception handling --- .../websocket/IntegrationWebSocketSecurityTest.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java index c1d2d81b0..0d47b608d 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java @@ -37,6 +37,7 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.notNull; import static org.mockito.Mockito.verify; @@ -86,7 +87,8 @@ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean wi assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).accessDeniedException(notNull(), notNull()); } WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) { @@ -131,7 +133,8 @@ void connectWithInvalidAuthorizationIsRejected() throws Throwable { assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).authenticationException(notNull(), notNull()); } /** @@ -167,7 +170,8 @@ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable { assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).authenticationException(notNull(), notNull()); } /** @@ -186,5 +190,6 @@ void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable { assertTrue(session.isConnected()); session.disconnect(); await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected()); + verify(webSocketExceptionHandler).delegate(notNull(), isNull()); } } From 2e9e2245bd0e49b8724726fc2230f49be5c44e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Ka=C5=88ka?= Date: Sun, 8 Dec 2024 11:09:47 +0100 Subject: [PATCH 3/3] [Bug #303] Add comments to WebSocketExceptionHandler#delegate methods and additionally check exception based on annotation value --- .../handler/StompExceptionHandler.java | 12 +++-- .../handler/WebSocketExceptionHandler.java | 54 ++++++++++++++----- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java index 5188b1811..d471f242f 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java @@ -1,6 +1,7 @@ package cz.cvut.kbss.termit.websocket.handler; import jakarta.annotation.Nonnull; +import jakarta.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; @@ -23,10 +24,15 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler @Override protected @Nonnull Message handleInternal(@Nonnull StompHeaderAccessor errorHeaderAccessor, - @Nonnull byte[] errorPayload, Throwable cause, - StompHeaderAccessor clientHeaderAccessor) { + @Nonnull byte[] errorPayload, + @Nullable Throwable cause, + @Nullable StompHeaderAccessor clientHeaderAccessor) { final Message message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build(); - final boolean handled = webSocketExceptionHandler.delegate(message, cause); + Throwable causeToHandle = cause; + if (causeToHandle != null && causeToHandle.getCause() != null) { + causeToHandle = causeToHandle.getCause(); + } + final boolean handled = webSocketExceptionHandler.delegate(message, causeToHandle); if (!handled) { LOG.error("STOMP sub-protocol exception", cause); diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java index b4239a095..dd5c574ad 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -45,11 +45,15 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause; /** - * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler} + * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler}.
+ * In order for the delegation to work, the method signature of MessageExceptionHandler methods must be {@code (Message, Exception)} */ @SendToUser @ControllerAdvice @@ -98,51 +102,73 @@ private static ErrorInfo errorInfo(Message message, TermItException e) { } /** - * Tries to match method on this object that matches signature with params + * Searches available methods annotated with {@link MessageExceptionHandler} in this class + * when the method signature matches {@code (Message, Exception)} + * and the exception parameter is assignable from the supplied throwable + * the method is called. * + * @param message the associated message + * @param throwable the exception to handle * @return true when a method was found and called, false otherwise - * @throws IllegalArgumentException never */ public boolean delegate(Message message, Throwable throwable) { try { - return delegateInternal(message, throwable.getCause()); + return delegateInternal(message, throwable); } catch (InvocationTargetException invEx) { + // Exception handler method threw an exception LOG.error("Exception thrown during exception handler invocation", invEx); } catch (IllegalAccessException unexpected) { - // is checked by delegate + // is checked by delegateInternal } return false; } /** - * Tries to match method on this object that matches signature with params + * Searches available methods annotated with {@link MessageExceptionHandler} in this class + * when the method signature matches {@code (Message, Exception)} + * and the exception parameter is assignable from the supplied throwable + * the method is called. * + * @param message the associated message + * @param throwable the exception to handle * @return true when a method was found and called, false otherwise * @throws IllegalArgumentException never + * @throws IllegalAccessException never + * @throws InvocationTargetException when the exception handler method throws an exception */ private boolean delegateInternal(Message message, Throwable throwable) throws InvocationTargetException, IllegalAccessException { + // handle only exceptions if (throwable instanceof Exception exception) { - Method[] methods = this.getClass().getMethods(); + // find all methods annotated with MessageExceptionHandler + List methods = Arrays.stream(this.getClass().getMethods()) + .filter(m -> m.isAnnotationPresent(MessageExceptionHandler.class)).toList(); for (final Method method : methods) { - if (!method.canAccess(this) || method.getName().equals("delegate") || method.getName().equals("delegateInternal")) { + // check for reflection access to prevent IllegalAccessException + if (!method.canAccess(this)) { continue; } + // we are interested only in methods with exactly two parameters (message, exception) Class[] params = method.getParameterTypes(); if (params.length != 2) { continue; } + // check if the MessageExceptionHandler annotation has value with allowed exceptions + Class[] allowedExceptions = Optional.ofNullable(method.getAnnotation(MessageExceptionHandler.class)) + .map(MessageExceptionHandler::value).orElseGet(() -> new Class[0]); + // if the exception is not allowed by the annotation, skip the method + if (allowedExceptions.length > 0 && Arrays.stream(allowedExceptions).noneMatch(e -> e.isAssignableFrom(exception.getClass()))) { + continue; + } + // validate the method signature if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { - // message, exception + // call the method with message, exception parameters method.invoke(this, message, exception); - return true; - } else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) { - // exception, message - method.invoke(this, exception, message); - return true; + return true; // exception was handled } } } + // throwable is not an exception or no suitable method was found return false; }