Skip to content

Commit

Permalink
parse request body, not params, for post requests
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Aug 29, 2023
1 parent 65dda47 commit 92328e8
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 28 deletions.
1 change: 1 addition & 0 deletions conversational-memory/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies {
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.ml.conversational.action.memory.conversation;

import java.io.IOException;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
Expand Down Expand Up @@ -80,8 +81,12 @@ public ActionRequestValidationException validate() {
* @throws IOException if something breaks
*/
public static CreateConversationRequest fromRestRequest(RestRequest restRequest) throws IOException {
if (restRequest.hasParam(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
return new CreateConversationRequest(restRequest.param(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD));
if (!restRequest.hasContent()) {
return new CreateConversationRequest();
}
Map<String, String> body = restRequest.contentParser().mapStrings();
if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
return new CreateConversationRequest(body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD));
} else {
return new CreateConversationRequest();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
Expand Down Expand Up @@ -91,12 +92,13 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong reading from request
*/
public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException {
Map<String, String> body = request.contentParser().mapStrings();
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String inp = request.param(ActionConstants.INPUT_FIELD);
String prmpt = request.param(ActionConstants.PROMPT_TEMPLATE_FIELD);
String rsp = request.param(ActionConstants.AI_RESPONSE_FIELD);
String ogn = request.param(ActionConstants.RESPONSE_ORIGIN_FIELD);
String metadata = request.param(ActionConstants.METADATA_FIELD);
String inp = body.get(ActionConstants.INPUT_FIELD);
String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD);
String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD);
String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD);
String metadata = body.get(ActionConstants.METADATA_FIELD);
return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, metadata);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,32 @@
import java.io.IOException;
import java.util.Map;

import org.junit.Before;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversational.ActionConstants;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

import com.google.gson.Gson;

public class CreateConversationRequestTests extends OpenSearchTestCase {

Gson gson;

@Before
public void setup() {
gson = new Gson();
}

public void testConstructorsAndStreaming_Named() throws IOException {
CreateConversationRequest request = new CreateConversationRequest("test-name");
assert (request.validate() == null);
Expand Down Expand Up @@ -67,7 +79,7 @@ public void testEmptyRestRequest() throws IOException {
public void testNamedRestRequest() throws IOException {
String name = "test-name";
RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name))
.withContent(new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name))), MediaTypeRegistry.JSON)
.build();
CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req);
assert (request.getName().equals(name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversational.ActionConstants;
import org.opensearch.rest.RestChannel;
Expand All @@ -36,7 +39,17 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

import com.google.gson.Gson;

public class CreateConversationRestActionTests extends OpenSearchTestCase {

Gson gson;

@Before
public void setup() {
gson = new Gson();
}

public void testBasics() {
CreateConversationRestAction action = new CreateConversationRestAction();
assert (action.getName().equals("conversational_create_conversation"));
Expand All @@ -48,7 +61,10 @@ public void testBasics() {
public void testPrepareRequest() throws Exception {
CreateConversationRestAction action = new CreateConversationRestAction();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "test-name"))
.withContent(
new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "test-name"))),
MediaTypeRegistry.JSON
)
.build();

NodeClient client = mock(NodeClient.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,32 @@
import java.io.IOException;
import java.util.Map;

import org.junit.Before;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversational.ActionConstants;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

import com.google.gson.Gson;

public class CreateInteractionRequestTests extends OpenSearchTestCase {

Gson gson;

@Before
public void setup() {
gson = new Gson();
}

public void testConstructorsAndStreaming() throws IOException {
CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "pt", "response", "origin", "metadata");
assert (request.validate() == null);
Expand Down Expand Up @@ -64,8 +76,6 @@ public void testNullCID_thenFail() {
public void testFromRestRequest() throws IOException {
Map<String, String> params = Map
.of(
ActionConstants.CONVERSATION_ID_FIELD,
"cid",
ActionConstants.INPUT_FIELD,
"input",
ActionConstants.PROMPT_TEMPLATE_FIELD,
Expand All @@ -77,7 +87,10 @@ public void testFromRestRequest() throws IOException {
ActionConstants.METADATA_FIELD,
"metadata"
);
RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid"))
.withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON)
.build();
CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest);
assert (request.validate() == null);
assert (request.getConversationId().equals("cid"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversational.ActionConstants;
import org.opensearch.rest.RestChannel;
Expand All @@ -36,8 +39,17 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

import com.google.gson.Gson;

public class CreateInteractionRestActionTests extends OpenSearchTestCase {

Gson gson;

@Before
public void setup() {
gson = new Gson();
}

public void testBasics() {
CreateInteractionRestAction action = new CreateInteractionRestAction();
assert (action.getName().equals("conversational_create_interaction"));
Expand All @@ -49,8 +61,6 @@ public void testBasics() {
public void testPrepareRequest() throws Exception {
Map<String, String> params = Map
.of(
ActionConstants.CONVERSATION_ID_FIELD,
"cid",
ActionConstants.INPUT_FIELD,
"input",
ActionConstants.PROMPT_TEMPLATE_FIELD,
Expand All @@ -63,7 +73,10 @@ public void testPrepareRequest() throws Exception {
"metadata"
);
CreateInteractionRestAction action = new CreateInteractionRestAction();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid"))
.withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON)
.build();

NodeClient client = mock(NodeClient.class);
RestChannel channel = mock(RestChannel.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ public void testCreateConversationNamed() throws IOException {
client(),
"POST",
ActionConstants.CREATE_CONVERSATION_REST_PATH,
Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name"),
"",
null,
gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")),
null
);
assert (response != null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@ public void testCreateInteraction() throws IOException {
"some metadata"
);
Response response = TestHelper
.makeRequest(client(), "POST", ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", id), params, "", null);
.makeRequest(
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", id),
null,
gson.toJson(params),
null
);
assert (response != null);
assert (TestHelper.restStatus(response) == RestStatus.OK);
HttpEntity httpEntity = response.getEntity();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ public void testDeleteConversation_WithInteractions() throws IOException {
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid),
params,
"",
null,
gson.toJson(params),
null
);
assert (ciresponse != null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void testGetInteractions_LastPage() throws IOException {
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid),
params,
"",
null,
gson.toJson(params),
null
);
assert (response != null);
Expand Down Expand Up @@ -154,8 +154,8 @@ public void testGetInteractions_MorePages() throws IOException {
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid),
params,
"",
null,
gson.toJson(params),
null
);
assert (response != null);
Expand Down Expand Up @@ -217,8 +217,8 @@ public void testGetInteractions_NextPage() throws IOException {
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid),
params,
"",
null,
gson.toJson(params),
null
);
assert (response != null);
Expand All @@ -234,8 +234,8 @@ public void testGetInteractions_NextPage() throws IOException {
client(),
"POST",
ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid),
params,
"",
null,
gson.toJson(params),
null
);
assert (response2 != null);
Expand Down

0 comments on commit 92328e8

Please sign in to comment.