forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[2.x] Adding an integration test for redeploying a model (opensearch-…
…project#1016) (opensearch-project#1264) * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core * Adding model group ID changes for tests * Fixing tests for ImmutableMap copy * Commenting wait out task for model * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core * Rebasing with 2.x * Adding logs to debug the test in GHA * GHA tests * Still debugging * Removing comment * Removing unnecessary changes * Removing logs --------- Signed-off-by: Sarat Vemulapalli <[email protected]> Co-authored-by: Sarat Vemulapalli <[email protected]>
- Loading branch information
1 parent
adb7ed3
commit de860dd
Showing
1 changed file
with
70 additions
and
0 deletions.
There are no files selected for viewing
70 changes: 70 additions & 0 deletions
70
plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.MLTaskState; | ||
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput; | ||
import org.opensearch.ml.utils.TestHelper; | ||
|
||
public class RestMLDeployModelActionIT extends MLCommonsRestTestCase { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
private MLRegisterModelInput registerModelInput; | ||
private MLRegisterModelGroupInput mlRegisterModelGroupInput; | ||
private String modelGroupId; | ||
|
||
@Before | ||
public void setup() throws IOException { | ||
mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build(); | ||
registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { | ||
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id"); | ||
}); | ||
registerModelInput = createRegisterModelInput(modelGroupId); | ||
} | ||
|
||
public void testReDeployModel() throws InterruptedException, IOException { | ||
// Register Model | ||
String taskId = registerModel(TestHelper.toJsonString(registerModelInput)); | ||
waitForTask(taskId, MLTaskState.COMPLETED); | ||
getTask(client(), taskId, response -> { | ||
String model_id = (String) response.get(MODEL_ID_FIELD); | ||
try { | ||
// Deploy Model | ||
String taskId1 = deployModel(model_id); | ||
getTask(client(), taskId1, innerResponse -> { assertEquals(model_id, innerResponse.get(MODEL_ID_FIELD)); }); | ||
waitForTask(taskId1, MLTaskState.COMPLETED); | ||
|
||
// Undeploy Model | ||
Map<String, Object> undeployresponse = undeployModel(model_id); | ||
for (Map.Entry<String, Object> entry : undeployresponse.entrySet()) { | ||
Map stats = (Map) ((Map) entry.getValue()).get("stats"); | ||
assertEquals("undeployed", stats.get(model_id)); | ||
} | ||
|
||
// Deploy Model again | ||
taskId1 = deployModel(model_id); | ||
getTask(client(), taskId1, innerResponse -> { logger.info("Re-Deploy model {}", innerResponse); }); | ||
waitForTask(taskId1, MLTaskState.COMPLETED); | ||
|
||
getModel(client(), model_id, model -> { | ||
logger.info("Get Model after re-deploy {}", model); | ||
assertEquals("DEPLOYED", model.get("model_state")); | ||
}); | ||
|
||
} catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} | ||
}); | ||
} | ||
} |