Skip to content

Commit

Permalink
add test that memory cb exception is caught by action listener
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 committed Mar 22, 2024
1 parent 82bbdf4 commit 14cf810
Showing 1 changed file with 19 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -112,6 +113,7 @@
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.suppliers.CounterSupplier;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.monitor.jvm.JvmService;
import org.opensearch.script.ScriptService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -449,6 +451,23 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
}

public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() throws PrivilegedActionException {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));

MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);

ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
Exception e = argCaptor.getValue();
assertTrue(e instanceof MLLimitExceededException);
assertEquals(memCBIsOpenMessage, e.getMessage());
}

public void testIndexRemoteModel() throws PrivilegedActionException {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down

0 comments on commit 14cf810

Please sign in to comment.