diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java index c9ce15e49f..93b32ec212 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java @@ -46,7 +46,7 @@ public TextSimilarityInputDataSet(String queryText, List textDocs) { Objects.requireNonNull(textDocs); Objects.requireNonNull(queryText); if(textDocs.isEmpty()) { - throw new IllegalArgumentException("pairs must be nonempty"); + throw new IllegalArgumentException("No text documents provided"); } this.textDocs = textDocs; this.queryText = queryText; diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java index 88ed7792dd..0c4d9f9a7b 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java @@ -105,10 +105,10 @@ public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) t } } if(docs.isEmpty()) { - throw new IllegalArgumentException("no text docs"); + throw new IllegalArgumentException("No text documents were provided"); } if(queryText == null) { - throw new IllegalArgumentException("no query text"); + throw new IllegalArgumentException("No query text was provided"); } inputDataset = new TextSimilarityInputDataSet(queryText, docs); } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java index 2813972b0e..040ed1968c 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java @@ -52,6 +52,14 @@ public void noPairs_ThenFail() { String queryText = "today is sunny"; IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); - assert (e.getMessage().equals("pairs must be nonempty")); + assert (e.getMessage().equals("No text documents provided")); + } + + @Test + public void noQuery_ThenFail() { + List docs = List.of("That is a happy dog", "it's summer"); + String queryText = null; + assertThrows(NullPointerException.class, + () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java index 8cca2794c7..296b939f5f 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java @@ -116,7 +116,7 @@ public void testParseJson_NoPairs_ThenFail() throws IOException { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> MLInput.parse(parser, input.getFunctionName().name())); - assert (e.getMessage().equals("no text docs")); + assert (e.getMessage().equals("No text documents were provided")); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java index 73a25a4ee7..97d31eb6f7 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java @@ -189,6 +189,34 @@ public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxExcepti textSimilarityCrossEncoderModel.close(); } + @Test + public void initModel_predict_ONNX_CrossEncoder() throws URISyntaxException { + model = MLModel + .builder() + .modelFormat(MLModelFormat.ONNX) + .name("test_model_name") + .modelId("test_model_id") + .algorithm(FunctionName.TEXT_SIMILARITY) + .version("1.0.0") + .modelState(MLModelState.TRAINED) + .build(); + modelZipFile = new File(getClass().getResource("TinyBERT-CE-onnx.zip").toURI()); + params.put(MODEL_ZIP_FILE, modelZipFile); + + textSimilarityCrossEncoderModel.initModel(model, params, encryptor); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) textSimilarityCrossEncoderModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + for (int i = 0; i < mlModelOutputs.size(); i++) { + ModelTensors tensors = mlModelOutputs.get(i); + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + assertEquals(1, mlModelTensors.get(0).getData().length); + } + textSimilarityCrossEncoderModel.close(); + } + @Test public void initModel_NullModelHelper() throws URISyntaxException { Map params = new HashMap<>(); diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip new file mode 100644 index 0000000000..fd8b7841e1 Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-onnx.zip differ