Skip to content

Commit

Permalink
remote inference: escape parameter if not valid json (opensearch-proj…
Browse files Browse the repository at this point in the history
…ect#1176)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent 49315fd commit 0ea3e3a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
Expand All @@ -34,7 +35,7 @@
import java.util.Map;
import java.util.Optional;

import static org.apache.commons.text.StringEscapeUtils.escapeJava;
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostprocessFunction;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction;
Expand Down Expand Up @@ -95,6 +96,18 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
} else {
throw new IllegalArgumentException("Wrong input type");
}
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
if (StringUtils.isJson(entry.getValue())) {
// no need to escape if it's already valid json
newParameters.put(entry.getKey(), entry.getValue());
} else {
newParameters.put(entry.getKey(), escapeJson(entry.getValue()));
}
});
inputData.setParameters(newParameters);
}
return inputData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
}

@Test
public void processInput_RemoteInferenceInputDataSet_EscapeString() {
String input = "hello \"world\" \n \t";
String expectedInput = "hello \\\"world\\\" \\n \\t";
processInput_RemoteInferenceInputDataSet(input, expectedInput);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NotEscapeStringValue() {
String input = "test value";
processInput_RemoteInferenceInputDataSet(input, input);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NotEscapeArrayString() {
String input = "[\"test value1\"]";
processInput_RemoteInferenceInputDataSet(input, input);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
String input = "{\"key1\": \"value\", \"key2\": 123}";
processInput_RemoteInferenceInputDataSet(input, input);
}

private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input"));
}

@Test
public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() {
processInput_TextDocsInputDataSet_PreprocessFunction(
Expand Down

0 comments on commit 0ea3e3a

Please sign in to comment.