Skip to content

Commit

Permalink
Move the ml_parameters from XContent to the request parameters to avo…
Browse files Browse the repository at this point in the history
…id the conflict with search XContent input. (opensearch-project#71)

* Create JvmService instance on demand.

Signed-off-by: Alex <[email protected]>

* Move the ml_parameters from XContent to the request parameters to avoid
the conflict with search XContent input.

Signed-off-by: Alex <[email protected]>

Co-authored-by: Alex <[email protected]>
  • Loading branch information
2 people authored and jackiehanyang committed Nov 16, 2021
1 parent 74271e6 commit c2133bc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 61 deletions.
49 changes: 20 additions & 29 deletions plugin/src/main/java/org/opensearch/ml/rest/BaseMLSearchAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.IntConsumer;
import java.util.stream.Collectors;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.ParsingException;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.search.RestSearchAction;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -103,36 +103,27 @@ SearchQueryInputDataset buildSearchQueryInput(RestRequest request, NodeClient cl
* @return MLParameter list
*/
@VisibleForTesting
List<MLParameter> getMLParameters(RestRequest request) throws IOException {
List<MLParameter> getMLParameters(RestRequest request) {
String parametersStr = request.param(ML_PARAMETERS);
List<MLParameter> parameters = new ArrayList<>();
if (request.hasContent()) {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (ML_PARAMETERS.equals(fieldName)) {
Map<String, Object> uiMetadata = parser.map();
parameters = uiMetadata
.entrySet()
.stream()
.map(e -> new MLParameter(e.getKey(), e.getValue()))
.collect(Collectors.toList());
break;
}
}
if (Strings.isNullOrEmpty(parametersStr)) {
return parameters;
}

return parameters;
}
if (parametersStr.charAt(0) != '{' && parametersStr.charAt(parametersStr.length() - 1) != '}') {
parametersStr = "{" + parametersStr + "}";
}

private void ensureExpectedToken(XContentParser.Token expected, XContentParser.Token actual, XContentParser parser) {
if (actual != expected) {
throw new ParsingException(
parser.getTokenLocation(),
String.format(Locale.ROOT, "Failed to parse object: expecting token of type [%s] but found [%s]", expected, actual),
new Object[0]
);
ObjectMapper mapper = new ObjectMapper();
try {
// Convert Map to JSON
Map<String, Object> map = mapper.readValue(parametersStr, new TypeReference<Map<String, Object>>() {
});
parameters = map.entrySet().stream().map(e -> new MLParameter(e.getKey(), e.getValue())).collect(Collectors.toList());
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("invalid ml_parameter: expected key=\"value\" or key=value [" + parametersStr + "]", e);
}

return parameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.rest.BaseMLSearchAction.ML_PARAMETERS;
import static org.opensearch.ml.rest.BaseMLSearchAction.PARAMETER_ALGORITHM;
import static org.opensearch.ml.rest.BaseMLSearchAction.PARAMETER_MODEL_ID;

Expand Down Expand Up @@ -151,63 +150,94 @@ public void testGetMLParametersWithoutRequestBody() throws IOException {

@Test
public void testGetMLParametersWithoutInput() throws IOException {
XContentBuilder xContentBuilder = XContentFactory
.jsonBuilder()
.startObject()
.startObject("type1")
.startObject("properties")
.startObject("location")
.field("type", "geo_point")
.endObject()
.endObject()
.endObject()
.endObject();

Map<String, String> param = ImmutableMap.of();
FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder);
Map<String, String> param = ImmutableMap
.<String, String>builder()
.put(PARAMETER_ALGORITHM, "kmeans")
.put("index", "index1,index2")
.put("q", "user:dilbert")
.build();
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build();

List<MLParameter> mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest);
assertTrue(mlParameters.isEmpty());
}

@Test
public void testGetMLParametersWithEmptyInput() throws IOException {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(ML_PARAMETERS).endObject().endObject();

Map<String, String> param = ImmutableMap.of();
FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder);
Map<String, String> param = ImmutableMap
.<String, String>builder()
.put(PARAMETER_ALGORITHM, "kmeans")
.put("index", "index1,index2")
.put("ml_parameters", "")
.build();
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build();

List<MLParameter> mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest);
assertTrue(mlParameters.isEmpty());
}

@Test
public void testGetMLParametersWithValidInput() throws IOException {
XContentBuilder xContentBuilder = XContentFactory
.jsonBuilder()
.startObject()
.startObject(ML_PARAMETERS)
.field("paramName1", "value1")
.field("paramName2", 123)
.endObject()
.endObject();
public void testGetMLParametersWithInvalidJsonInput() throws IOException {
thrown.expect(IllegalArgumentException.class);
Map<String, String> param = ImmutableMap
.<String, String>builder()
.put(PARAMETER_ALGORITHM, "kmeans")
.put("index", "index1,index2")
.put("ml_parameters", "{paramName1\":\"value1\",\"paramName2\":123}")
.build();
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build();

Map<String, String> param = ImmutableMap.of();
FakeRestRequest fakeRestRequest = buildFakeRestRequest(param, xContentBuilder);
baseMLSearchAction.getMLParameters(fakeRestRequest);
}

@Test
public void testGetMLParametersWithValidInput() throws IOException {
Map<String, String> param = ImmutableMap
.<String, String>builder()
.put(PARAMETER_ALGORITHM, "kmeans")
.put("index", "index1,index2")
.put("ml_parameters", "{\"paramName1\":\"value1\",\"paramName2\":123}")
.build();
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build();

List<MLParameter> mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest);
assertFalse(mlParameters.isEmpty());
assertEquals(2, mlParameters.size());

MLParameter mlParam2 = mlParameters.get(0);
MLParameter mlParam1 = mlParameters.get(0);
assertNotNull(mlParam1);
assertEquals("paramName1", mlParam1.getName());
assertEquals("value1", mlParam1.getValue());

MLParameter mlParam2 = mlParameters.get(1);
assertNotNull(mlParam2);
assertEquals("paramName2", mlParam2.getName());
assertEquals(123, mlParam2.getValue());
}

MLParameter mlParam1 = mlParameters.get(1);
@Test
public void testGetMLParametersWithoutBrace() throws IOException {
Map<String, String> param = ImmutableMap
.<String, String>builder()
.put(PARAMETER_ALGORITHM, "kmeans")
.put("index", "index1,index2")
.put("ml_parameters", "\"paramName1\":\"value1\",\"paramName2\":123")
.build();
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()).withParams(param).build();

List<MLParameter> mlParameters = baseMLSearchAction.getMLParameters(fakeRestRequest);
assertFalse(mlParameters.isEmpty());
assertEquals(2, mlParameters.size());

MLParameter mlParam1 = mlParameters.get(0);
assertNotNull(mlParam1);
assertEquals("paramName1", mlParam1.getName());
assertEquals("value1", mlParam1.getValue());

MLParameter mlParam2 = mlParameters.get(1);
assertNotNull(mlParam2);
assertEquals("paramName2", mlParam2.getName());
assertEquals(123, mlParam2.getValue());
}

@Test
Expand Down

0 comments on commit c2133bc

Please sign in to comment.