From 0ea3e3af0ba15a5444adb6672b8e0e2480a7202e Mon Sep 17 00:00:00 2001
From: Yaliang Wu
Date: Tue, 1 Aug 2023 00:24:06 -0700
Subject: [PATCH] remote inference: escape parameter if not valid json (#1176)

When the input data set has non-null parameters,
ConnectorUtils.processInput now rebuilds the parameter map before
returning: a value accepted by StringUtils.isJson is kept unchanged,
any other value is passed through escapeJson. The static import of
escapeJava is replaced by escapeJson. New tests cover a string with
quotes, newline and tab, a plain string, a JSON array string and a
JSON object string.

Signed-off-by: Yaliang Wu
---
 .../algorithms/remote/ConnectorUtils.java     | 15 ++++++-
 .../algorithms/remote/ConnectorUtilsTest.java | 40 +++++++++++++++++++
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
index 7eccd6155d..588da7ccae 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
@@ -15,6 +16,7 @@
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.common.utils.StringUtils;
 import org.opensearch.script.ScriptService;
 import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
 import software.amazon.awssdk.auth.credentials.AwsCredentials;
@@ -34,7 +35,7 @@
 import java.util.Map;
 import java.util.Optional;
 
-import static org.apache.commons.text.StringEscapeUtils.escapeJava;
+import static org.apache.commons.text.StringEscapeUtils.escapeJson;
 import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
 import static org.opensearch.ml.engine.utils.ScriptUtils.executePostprocessFunction;
 import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction;
@@ -95,6 +96,18 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
         } else {
             throw new IllegalArgumentException("Wrong input type");
         }
+        if (inputData.getParameters() != null) {
+            Map<String, String> newParameters = new HashMap<>();
+            inputData.getParameters().entrySet().forEach(entry -> {
+                if (StringUtils.isJson(entry.getValue())) {
+                    // no need to escape if it's already valid json
+                    newParameters.put(entry.getKey(), entry.getValue());
+                } else {
+                    newParameters.put(entry.getKey(), escapeJson(entry.getValue()));
+                }
+            });
+            inputData.setParameters(newParameters);
+        }
         return inputData;
     }
 
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
index 857cbe997f..bfe5023b9a 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java
@@ -71,6 +71,46 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
         ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
     }
 
+    @Test
+    public void processInput_RemoteInferenceInputDataSet_EscapeString() {
+        String input = "hello \"world\" \n \t";
+        String expectedInput = "hello \\\"world\\\" \\n \\t";
+        processInput_RemoteInferenceInputDataSet(input, expectedInput);
+    }
+
+    @Test
+    public void processInput_RemoteInferenceInputDataSet_NotEscapeStringValue() {
+        String input = "test value";
+        processInput_RemoteInferenceInputDataSet(input, input);
+    }
+
+    @Test
+    public void processInput_RemoteInferenceInputDataSet_NotEscapeArrayString() {
+        String input = "[\"test value1\"]";
+        processInput_RemoteInferenceInputDataSet(input, input);
+    }
+
+    @Test
+    public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
+        String input = "{\"key1\": \"value\", \"key2\": 123}";
+        processInput_RemoteInferenceInputDataSet(input, input);
+    }
+
+    private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
+        RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
+        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();
+
+        ConnectorAction predictAction = ConnectorAction.builder()
+                .actionType(ConnectorAction.ActionType.PREDICT)
+                .method("POST")
+                .url("http://test.com/mock")
+                .requestBody("{\"input\": \"${parameters.input}\"}")
+                .build();
+        Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
+        ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
+        Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input"));
+    }
+
     @Test
     public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() {
         processInput_TextDocsInputDataSet_PreprocessFunction(