Skip to content

Commit

Permalink
add error message for non-torchscript cross encoders
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Nov 28, 2023
1 parent 0731463 commit cd4a583
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.engine.algorithms.text_similarity;

import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;

import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -58,8 +60,15 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
}

@Override
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
return new TextSimilarityTranslator();
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
if(PYTORCH_ENGINE.equals(engine)) {
return new TextSimilarityTranslator();
} else {
throw new IllegalArgumentException("Wrong deep learning engine ["
+ engine
+ "]. Only TORCH_SCRIPT is supported for function name "
+ FunctionName.TEXT_SIMILARITY.name());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,22 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti
textSimilarityCrossEncoderModel.close();
}

@Test
public void initModel_predict_ONNX_CrossEncoder_ThenFail() throws URISyntaxException {
model = MLModel.builder()
.modelFormat(MLModelFormat.ONNX)
.name("test_model_name")
.modelId("test_model_id")
.algorithm(FunctionName.TEXT_SIMILARITY)
.version("1.0.0")
.modelState(MLModelState.TRAINED)
.build();
assertThrows("Wrong deep learning engine [OnnxRuntime]. Only TORCH_SCRIPT is supported for function name TEXT_SIMILARITY",
MLException.class,
() -> textSimilarityCrossEncoderModel.initModel(model, params, encryptor)
);
}

@Test
public void initModel_NullModelHelper() throws URISyntaxException {
Map<String, Object> params = new HashMap<>();
Expand Down Expand Up @@ -265,7 +281,7 @@ public void predict_AfterModelClosed() {
log.info(e.getMessage());
assert (e.getMessage().startsWith("Failed to inference TEXT_SIMILARITY"));
}

@After
public void tearDown() {
FileUtils.deleteFileQuietly(mlCachePath);
Expand Down

0 comments on commit cd4a583

Please sign in to comment.