From c2133bceacca8d6a0b8727565a5d64b65b559b7b Mon Sep 17 00:00:00 2001 From: Alex Sun <52179851+spbjss@users.noreply.github.com> Date: Wed, 25 Aug 2021 15:57:48 -0700 Subject: [PATCH] Move the ml_parameters from XContent to the request parameters to avoid the conflict with search XContent input. (#71) * Create JvmService instance on demand. Signed-off-by: Alex * Move the ml_parameters from XContent to the request parameters to avoid the conflict with search XContent input. Signed-off-by: Alex Co-authored-by: Alex --- .../ml/rest/BaseMLSearchAction.java | 49 ++++------ .../ml/rest/BaseMLSearchActionTests.java | 94 ++++++++++++------- 2 files changed, 82 insertions(+), 61 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/BaseMLSearchAction.java b/plugin/src/main/java/org/opensearch/ml/rest/BaseMLSearchAction.java index 24b076d857..00ad6b8638 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/BaseMLSearchAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/BaseMLSearchAction.java @@ -4,21 +4,21 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.function.IntConsumer; import java.util.stream.Collectors; import org.opensearch.action.search.SearchRequest; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.ParsingException; import org.opensearch.common.Strings; -import org.opensearch.common.xcontent.XContentParser; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.parameter.MLParameter; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.search.RestSearchAction; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -103,36 +103,27 @@ SearchQueryInputDataset buildSearchQueryInput(RestRequest request, NodeClient cl * @return MLParameter list */ @VisibleForTesting - List getMLParameters(RestRequest request) throws IOException { + List getMLParameters(RestRequest request) { + String parametersStr = request.param(ML_PARAMETERS); List parameters = new ArrayList<>(); - if (request.hasContent()) { - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - if (ML_PARAMETERS.equals(fieldName)) { - Map uiMetadata = parser.map(); - parameters = uiMetadata - .entrySet() - .stream() - .map(e -> new MLParameter(e.getKey(), e.getValue())) - .collect(Collectors.toList()); - break; - } - } + if (Strings.isNullOrEmpty(parametersStr)) { + return parameters; } - return parameters; - } + if (parametersStr.charAt(0) != '{' && parametersStr.charAt(parametersStr.length() - 1) != '}') { + parametersStr = "{" + parametersStr + "}"; + } - private void ensureExpectedToken(XContentParser.Token expected, XContentParser.Token actual, XContentParser parser) { - if (actual != expected) { - throw new ParsingException( - parser.getTokenLocation(), - String.format(Locale.ROOT, "Failed to parse object: expecting token of type [%s] but found [%s]", expected, actual), - new Object[0] - ); + ObjectMapper mapper = new ObjectMapper(); + try { + // Convert Map to JSON + Map map = mapper.readValue(parametersStr, new TypeReference>() { + }); + parameters = map.entrySet().stream().map(e -> new MLParameter(e.getKey(), e.getValue())).collect(Collectors.toList()); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("invalid ml_parameter: expected key=\"value\" or key=value [" + parametersStr + "]", e); } + + return parameters; } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java index 1df4d5aae8..80f039b046 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/BaseMLSearchActionTests.java @@ -2,7 +2,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ml.rest.BaseMLSearchAction.ML_PARAMETERS; import static org.opensearch.ml.rest.BaseMLSearchAction.PARAMETER_ALGORITHM; import static org.opensearch.ml.rest.BaseMLSearchAction.PARAMETER_MODEL_ID; @@ -151,20 +150,13 @@ public void testGetMLParametersWithoutRequestBody() throws IOException { @Test public void testGetMLParametersWithoutInput() throws IOException { - XContentBuilder xContentBuilder = XContentFactory - .jsonBuilder() - .startObject() - .startObject("type1") - .startObject("properties") - .startObject("location") - .field("type", "geo_point") - .endObject() - .endObject() - .endObject() - .endObject(); - - Map param = ImmutableMap.of(); - FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder); + Map param = ImmutableMap + .builder() + .put(PARAMETER_ALGORITHM, "kmeans") + .put("index", "index1,index2") + .put("q", "user:dilbert") + .build(); + FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); List mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest); assertTrue(mlParameters.isEmpty()); @@ -172,42 +164,80 @@ public void testGetMLParametersWithoutInput() throws IOException { @Test public void testGetMLParametersWithEmptyInput() throws IOException { - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(ML_PARAMETERS).endObject().endObject(); - - Map param = ImmutableMap.of(); - FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder); + Map param = ImmutableMap + .builder() + .put(PARAMETER_ALGORITHM, "kmeans") + .put("index", "index1,index2") + .put("ml_parameters", "") + .build(); + FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); List mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest); assertTrue(mlParameters.isEmpty()); } @Test - public void testGetMLParametersWithValidInput() throws IOException { - XContentBuilder xContentBuilder = XContentFactory - .jsonBuilder() - .startObject() - .startObject(ML_PARAMETERS) - .field("paramName1", "value1") - .field("paramName2", 123) - .endObject() - .endObject(); + public void testGetMLParametersWithInvalidJsonInput() throws IOException { + thrown.expect(IllegalArgumentException.class); + Map param = ImmutableMap + .builder() + .put(PARAMETER_ALGORITHM, "kmeans") + .put("index", "index1,index2") + .put("ml_parameters", "{paramName1\":\"value1\",\"paramName2\":123}") + .build(); + FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); - Map param = ImmutableMap.of(); - FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder); + baseMLSearchAction.getMLParameters(fakeRestRequest); + } + + @Test + public void testGetMLParametersWithValidInput() throws IOException { + Map param = ImmutableMap + .builder() + .put(PARAMETER_ALGORITHM, "kmeans") + .put("index", "index1,index2") + .put("ml_parameters", "{\"paramName1\":\"value1\",\"paramName2\":123}") + .build(); + FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); List mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest); assertFalse(mlParameters.isEmpty()); assertEquals(2, mlParameters.size()); - MLParameter mlParam2 = mlParameters.get(0); + MLParameter mlParam1 = mlParameters.get(0); + assertNotNull(mlParam1); + assertEquals("paramName1", mlParam1.getName()); + assertEquals("value1", mlParam1.getValue()); + + MLParameter mlParam2 = mlParameters.get(1); assertNotNull(mlParam2); assertEquals("paramName2", mlParam2.getName()); assertEquals(123, mlParam2.getValue()); + } - MLParameter mlParam1 = mlParameters.get(1); + @Test + public void testGetMLParametersWithoutBrace() throws IOException { + Map param = ImmutableMap + .builder() + .put(PARAMETER_ALGORITHM, "kmeans") + .put("index", "index1,index2") + .put("ml_parameters", "\"paramName1\":\"value1\",\"paramName2\":123") + .build(); + FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build(); + + List mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest); + assertFalse(mlParameters.isEmpty()); + assertEquals(2, mlParameters.size()); + + MLParameter mlParam1 = mlParameters.get(0); assertNotNull(mlParam1); assertEquals("paramName1", mlParam1.getName()); assertEquals("value1", mlParam1.getValue()); + + MLParameter mlParam2 = mlParameters.get(1); + assertNotNull(mlParam2); + assertEquals("paramName2", mlParam2.getName()); + assertEquals(123, mlParam2.getValue()); } @Test