diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/WebsocketUtilTestCase.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/WebsocketUtilTestCase.java index dbf8a94b9b56..a11854b773a8 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/WebsocketUtilTestCase.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/WebsocketUtilTestCase.java @@ -138,6 +138,16 @@ public void testIsThrottled() { Assert.assertTrue(WebsocketUtil.isThrottled(resourceKey, subscriptionKey, apiKey)); } + @Test + public void testGetThrottleStatus() { + ThrottleDataHolder throttleDataHolder = Mockito.mock(ThrottleDataHolder.class); + Mockito.when(serviceReferenceHolder.getThrottleDataHolder()).thenReturn(throttleDataHolder); + Mockito.when(throttleDataHolder.isAPIThrottled(apiKey)).thenReturn(true); + Mockito.when(throttleDataHolder.isAPIThrottled(resourceKey)).thenReturn(true); + Mockito.when(throttleDataHolder.isAPIThrottled(subscriptionKey)).thenReturn(true); + Assert.assertTrue(WebsocketUtil.getThrottleStatus(resourceKey, subscriptionKey, apiKey).isThrottled()); + } + @Test public void testGetAccessTokenCacheKey() { Assert.assertEquals("235erwytgtkyb:/ishara:/resource", diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/inbound/websocket/utils/InboundWebsocketProcessorUtilTest.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/inbound/websocket/utils/InboundWebsocketProcessorUtilTest.java index 4341e986f98b..22b18080f361 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/inbound/websocket/utils/InboundWebsocketProcessorUtilTest.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/inbound/websocket/utils/InboundWebsocketProcessorUtilTest.java @@ -32,6 +32,7 @@ import org.testng.Assert; import org.wso2.carbon.apimgt.common.gateway.constants.GraphQLConstants; import org.wso2.carbon.apimgt.api.APIManagementException; +import org.wso2.carbon.apimgt.gateway.dto.WebSocketThrottleResponseDTO; import org.wso2.carbon.apimgt.gateway.handlers.Utils; import org.wso2.carbon.apimgt.gateway.handlers.WebsocketUtil; import org.wso2.carbon.apimgt.gateway.handlers.WebsocketWSClient; @@ -154,8 +155,8 @@ public void testDoThrottleSuccessForGraphQL() throws ParseException { + inboundMessageContext.getApiContext() + ":" + inboundMessageContext.getVersion(); String applicationLevelThrottleKey = apiKeyValidationInfoDTO.getApplicationId() + ":" + apiKeyValidationInfoDTO.getSubscriber() + "@" + apiKeyValidationInfoDTO.getSubscriberTenantDomain(); - PowerMockito.when(WebsocketUtil.isThrottled(verbInfoDTO.getRequestKey(), subscriptionLevelThrottleKey, - applicationLevelThrottleKey)).thenReturn(false); + PowerMockito.when(WebsocketUtil.getThrottleStatus(verbInfoDTO.getRequestKey(), subscriptionLevelThrottleKey, + applicationLevelThrottleKey)).thenReturn(null); Mockito.when(dataPublisher.tryPublish(Mockito.anyObject())).thenReturn(true); InboundProcessorResponseDTO inboundProcessorResponseDTO = InboundWebsocketProcessorUtil.doThrottleForGraphQL(msgSize, verbInfoDTO, inboundMessageContext, @@ -185,18 +186,22 @@ public void testDoThrottleFail() throws ParseException { inboundMessageContext.setVersion("1.0.0"); inboundMessageContext.setUserIP("198.162.10.2"); inboundMessageContext.setInfoDTO(apiKeyValidationInfoDTO); + WebSocketThrottleResponseDTO throttleResponseDTO = new WebSocketThrottleResponseDTO(); + throttleResponseDTO.setThrottled(true); + throttleResponseDTO.setThrottledOutReason("Throttled due to application-level constraints"); - String subscriptionLevelThrottleKey = apiKeyValidationInfoDTO.getApplicationId() + ":" - + inboundMessageContext.getApiContext() + ":" + inboundMessageContext.getVersion(); - String applicationLevelThrottleKey = apiKeyValidationInfoDTO.getApplicationId() + ":" - + apiKeyValidationInfoDTO.getSubscriber() + "@" + apiKeyValidationInfoDTO.getSubscriberTenantDomain(); + String subscriptionLevelThrottleKey = + apiKeyValidationInfoDTO.getApplicationId() + ":" + inboundMessageContext.getApiContext() + ":" + + inboundMessageContext.getVersion(); + String applicationLevelThrottleKey = + apiKeyValidationInfoDTO.getApplicationId() + ":" + apiKeyValidationInfoDTO.getSubscriber() + "@" + + apiKeyValidationInfoDTO.getSubscriberTenantDomain(); Mockito.when(dataPublisher.tryPublish(Mockito.anyObject())).thenReturn(true); - PowerMockito.when(WebsocketUtil.isThrottled(verbInfoDTO.getRequestKey(), subscriptionLevelThrottleKey, - applicationLevelThrottleKey)).thenReturn(true); - InboundProcessorResponseDTO inboundProcessorResponseDTO = - InboundWebsocketProcessorUtil.doThrottleForGraphQL(msgSize, verbInfoDTO, inboundMessageContext, - operationId); + PowerMockito.when(WebsocketUtil.getThrottleStatus(verbInfoDTO.getRequestKey(), subscriptionLevelThrottleKey, + applicationLevelThrottleKey)).thenReturn(throttleResponseDTO); + InboundProcessorResponseDTO inboundProcessorResponseDTO = InboundWebsocketProcessorUtil.doThrottleForGraphQL( + msgSize, verbInfoDTO, inboundMessageContext, operationId); Assert.assertTrue(inboundProcessorResponseDTO.isError()); Assert.assertEquals(inboundProcessorResponseDTO.getErrorMessage(), WebSocketApiConstants.FrameErrorConstants.THROTTLED_OUT_ERROR_MESSAGE);