Skip to content

Commit

Permalink
add singular get rest actions
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Oct 11, 2023
1 parent f8330ee commit e379420
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ public class ActionConstants {
private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for list conversations */
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for put interaction */
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
Expand All @@ -66,6 +66,10 @@ public class ActionConstants {
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
/** path for get interaction */
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class GetInteractionsAction extends ActionType<GetInteractionsResponse> {
/** Instance of this */
public static final GetInteractionsAction INSTANCE = new GetInteractionsAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get";
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/list";

private GetInteractionsAction() {
super(NAME, GetInteractionsResponse::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@
import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction;
import org.opensearch.ml.memory.action.conversation.DeleteConversationAction;
import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction;
import org.opensearch.ml.memory.action.conversation.GetConversationAction;
import org.opensearch.ml.memory.action.conversation.GetConversationTransportAction;
import org.opensearch.ml.memory.action.conversation.GetConversationsAction;
import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction;
import org.opensearch.ml.memory.action.conversation.SearchConversationsAction;
Expand Down Expand Up @@ -172,7 +176,9 @@
import org.opensearch.ml.rest.RestMemoryCreateConversationAction;
import org.opensearch.ml.rest.RestMemoryCreateInteractionAction;
import org.opensearch.ml.rest.RestMemoryDeleteConversationAction;
import org.opensearch.ml.rest.RestMemoryGetConversationAction;
import org.opensearch.ml.rest.RestMemoryGetConversationsAction;
import org.opensearch.ml.rest.RestMemoryGetInteractionAction;
import org.opensearch.ml.rest.RestMemoryGetInteractionsAction;
import org.opensearch.ml.rest.RestMemorySearchConversationsAction;
import org.opensearch.ml.rest.RestMemorySearchInteractionsAction;
Expand Down Expand Up @@ -302,7 +308,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class),
new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class),
new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class),
new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class)
new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class),
new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class),
new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class)
);
}

Expand Down Expand Up @@ -554,6 +562,8 @@ public List<RestHandler> getRestHandlers(
RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction();
RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction();
RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction();
RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction();
RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction();
return ImmutableList
.of(
restMLStatsAction,
Expand Down Expand Up @@ -587,7 +597,9 @@ public List<RestHandler> getRestHandlers(
restListInteractionsAction,
restDeleteConversationAction,
restSearchConversationsAction,
restSearchInteractionsAction
restSearchInteractionsAction,
restGetConversationAction,
restGetInteractionAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.List;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.memory.action.conversation.GetConversationAction;
import org.opensearch.ml.memory.action.conversation.GetConversationRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.collect.ImmutableList;

public class RestMemoryGetConversationAction extends BaseRestHandler {
private final static String GET_CONVERSATION_NAME = "conversational_get_conversation";

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH));
}

@Override
public String getName() {
return GET_CONVERSATION_NAME;
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
GetConversationRequest gcRequest = GetConversationRequest.fromRestRequest(request);
return channel -> client.execute(GetConversationAction.INSTANCE, gcRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.List;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.collect.ImmutableList;

public class RestMemoryGetInteractionAction extends BaseRestHandler {
private final static String GET_INTERACTION_NAME = "conversational_get_interaction";

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH));
}

@Override
public String getName() {
return GET_INTERACTION_NAME;
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
GetInteractionRequest giRequest = GetInteractionRequest.fromRestRequest(request);
return channel -> client.execute(GetInteractionAction.INSTANCE, giRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.Map;

import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.utils.TestHelper;

import com.google.common.collect.ImmutableList;

public class RestMemoryGetConversationActionIT extends MLCommonsRestTestCase {
@Before
public void setupFeatureSettings() throws IOException {
Response response = TestHelper
.makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
"{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}",
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(200, response.getStatusLine().getStatusCode());
}

public void testGetConversation() throws IOException {
Response ccresponse = TestHelper
.makeRequest(
client(),
"POST",
ActionConstants.CREATE_CONVERSATION_REST_PATH,
null,
gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")),
null
);
assert (ccresponse != null);
assert (TestHelper.restStatus(ccresponse) == RestStatus.OK);
HttpEntity cchttpEntity = ccresponse.getEntity();
String ccentityString = TestHelper.httpEntityToString(cchttpEntity);
@SuppressWarnings("unchecked")
Map<String, String> ccmap = gson.fromJson(ccentityString, Map.class);
assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD));
String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD);

Response gcresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null);
assert (gcresponse != null);
assert (TestHelper.restStatus(gcresponse) == RestStatus.OK);
HttpEntity gchttpEntity = gcresponse.getEntity();
String gcentitiyString = TestHelper.httpEntityToString(gchttpEntity);
@SuppressWarnings("unchecked")
Map<String, String> gcmap = gson.fromJson(gcentitiyString, Map.class);
assert (gcmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gcmap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(id));
assert (gcmap.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)
&& gcmap.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD).equals("name"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.List;
import java.util.Map;

import org.mockito.ArgumentCaptor;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.memory.action.conversation.GetConversationAction;
import org.opensearch.ml.memory.action.conversation.GetConversationRequest;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler.Route;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

public class RestMemoryGetConversationActionTests extends OpenSearchTestCase {
public void testBasics() {
RestMemoryGetConversationAction action = new RestMemoryGetConversationAction();
assert (action.getName().equals("conversational_get_conversation"));
List<Route> routes = action.routes();
assert (routes.size() == 1);
assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH)));
}

public void testPrepareRequest() throws Exception {
RestMemoryGetConversationAction action = new RestMemoryGetConversationAction();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid"))
.build();

NodeClient client = mock(NodeClient.class);
RestChannel channel = mock(RestChannel.class);
action.handleRequest(request, channel, client);

ArgumentCaptor<GetConversationRequest> argCaptor = ArgumentCaptor.forClass(GetConversationRequest.class);
verify(client, times(1)).execute(eq(GetConversationAction.INSTANCE), argCaptor.capture(), any());
assert (argCaptor.getValue().getConversationId().equals("cid"));
}
}
Loading

0 comments on commit e379420

Please sign in to comment.