diff --git a/backend/src/main/java/com/twtw/backend/config/socket/StompConfig.java b/backend/src/main/java/com/twtw/backend/config/socket/StompConfig.java index db01555a..fafbd77f 100644 --- a/backend/src/main/java/com/twtw/backend/config/socket/StompConfig.java +++ b/backend/src/main/java/com/twtw/backend/config/socket/StompConfig.java @@ -1,11 +1,10 @@ package com.twtw.backend.config.socket; import com.twtw.backend.global.properties.RabbitMQProperties; - import lombok.RequiredArgsConstructor; - import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Profile; +import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.util.AntPathMatcher; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; @@ -17,7 +16,9 @@ @RequiredArgsConstructor @EnableWebSocketMessageBroker public class StompConfig implements WebSocketMessageBrokerConfigurer { + private final RabbitMQProperties rabbitMQProperties; + private final StompHandler stompHandler; @Override public void registerStompEndpoints(final StompEndpointRegistry registry) { @@ -36,4 +37,9 @@ public void configureMessageBroker(final MessageBrokerRegistry registry) { registry.setApplicationDestinationPrefixes("/pub"); } + + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.interceptors(stompHandler); + } } diff --git a/backend/src/main/java/com/twtw/backend/config/socket/StompHandler.java b/backend/src/main/java/com/twtw/backend/config/socket/StompHandler.java new file mode 100644 index 00000000..4d29ca69 --- /dev/null +++ b/backend/src/main/java/com/twtw/backend/config/socket/StompHandler.java @@ -0,0 +1,48 @@ +package com.twtw.backend.config.socket; + +import com.twtw.backend.config.security.jwt.TokenProvider; +import lombok.RequiredArgsConstructor; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.stereotype.Component; + +import java.util.Optional; + +@Component +@RequiredArgsConstructor +@Order(Ordered.HIGHEST_PRECEDENCE + 99) +public class StompHandler implements ChannelInterceptor { + + private static final String AUTHORIZATION_HEADER = "Authorization"; + private static final String BEARER_PREFIX = "Bearer "; + private final TokenProvider tokenProvider; + + @Override + public Message preSend(final Message message, final MessageChannel channel) { + final StompHeaderAccessor acessor = StompHeaderAccessor.wrap(message); + + if (StompCommand.CONNECT == acessor.getCommand()) { + final Optional headerValue = Optional.ofNullable(acessor.getFirstNativeHeader(AUTHORIZATION_HEADER)); + resolveToken(headerValue).ifPresent(header -> { + tokenProvider.validateToken(header); + SecurityContextHolder.getContext().setAuthentication(tokenProvider.getAuthentication(header)); + }); + } + return message; + } + + private Optional resolveToken(final Optional headerValue) { + return headerValue.map(header -> { + if (header.startsWith(BEARER_PREFIX)) { + return header.substring(7); + } + return header; + }); + } +}