diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 96431e7679..f87da7c433 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -54,9 +54,9 @@ public class ActionConstants { private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; - /** path for list conversations */ + /** path for get conversations */ public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; - /** path for put interaction */ + /** path for create interaction */ public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; @@ -66,6 +66,10 @@ public class ActionConstants { public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + /** path for get conversation */ + public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + /** path for get interaction */ + public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}"; /** default max results returned by get operations */ public final static int DEFAULT_MAX_RESULTS = 10; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java index 024abe17ff..7a49d062d5 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java @@ -26,7 +26,7 @@ public class GetInteractionsAction extends ActionType { /** Instance of this */ public static final GetInteractionsAction INSTANCE = new GetInteractionsAction(); /** Name of this action */ - public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/list"; private GetInteractionsAction() { super(NAME, GetInteractionsResponse::new); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 4c7f9a275a..7db504f905 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -133,8 +133,12 @@ import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationTransportAction; import org.opensearch.ml.memory.action.conversation.GetConversationsAction; import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; @@ -172,7 +176,9 @@ import org.opensearch.ml.rest.RestMemoryCreateConversationAction; import org.opensearch.ml.rest.RestMemoryCreateInteractionAction; import org.opensearch.ml.rest.RestMemoryDeleteConversationAction; +import org.opensearch.ml.rest.RestMemoryGetConversationAction; import org.opensearch.ml.rest.RestMemoryGetConversationsAction; +import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; import org.opensearch.ml.rest.RestMemorySearchConversationsAction; import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; @@ -302,7 +308,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), - new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class) + new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), + new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), + new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class) ); } @@ -554,6 +562,8 @@ public List getRestHandlers( RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); + RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); + RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); return ImmutableList .of( restMLStatsAction, @@ -587,7 +597,9 @@ public List getRestHandlers( restListInteractionsAction, restDeleteConversationAction, restSearchConversationsAction, - restSearchInteractionsAction + restSearchInteractionsAction, + restGetConversationAction, + restGetInteractionAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java new file mode 100644 index 0000000000..dbabd40953 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationAction extends BaseRestHandler { + private final static String GET_CONVERSATION_NAME = "conversational_get_conversation"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH)); + } + + @Override + public String getName() { + return GET_CONVERSATION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetConversationRequest gcRequest = GetConversationRequest.fromRestRequest(request); + return channel -> client.execute(GetConversationAction.INSTANCE, gcRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java new file mode 100644 index 0000000000..ad2b35dbf6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionAction extends BaseRestHandler { + private final static String GET_INTERACTION_NAME = "conversational_get_interaction"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH)); + } + + @Override + public String getName() { + return GET_INTERACTION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetInteractionRequest giRequest = GetInteractionRequest.fromRestRequest(request); + return channel -> client.execute(GetInteractionAction.INSTANCE, giRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java new file mode 100644 index 0000000000..5a55b1c301 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetConversation() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Response gcresponse = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null); + assert (gcresponse != null); + assert (TestHelper.restStatus(gcresponse) == RestStatus.OK); + HttpEntity gchttpEntity = gcresponse.getEntity(); + String gcentitiyString = TestHelper.httpEntityToString(gchttpEntity); + @SuppressWarnings("unchecked") + Map gcmap = gson.fromJson(gcentitiyString, Map.class); + assert (gcmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gcmap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(id)); + assert (gcmap.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) + && gcmap.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD).equals("name")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java new file mode 100644 index 0000000000..0e81f2dacb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetConversationActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + assert (action.getName().equals("conversational_get_conversation")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationRequest.class); + verify(client, times(1)).execute(eq(GetConversationAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java new file mode 100644 index 0000000000..691195a99b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetInteraction() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response ciresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (ciresponse != null); + assert (TestHelper.restStatus(ciresponse) == RestStatus.OK); + HttpEntity cihttpEntity = ciresponse.getEntity(); + String cientityString = TestHelper.httpEntityToString(cihttpEntity); + @SuppressWarnings("unchecked") + Map cimap = gson.fromJson(cientityString, Map.class); + assert (cimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD)); + String iid = cimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + + Response giresponse = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTION_REST_PATH.replace("{conversation_id}", cid).replace("{interaction_id}", iid), + null, + "", + null + ); + assert (giresponse != null); + assert (TestHelper.restStatus(giresponse) == RestStatus.OK); + HttpEntity gihttpEntity = giresponse.getEntity(); + String gientityString = TestHelper.httpEntityToString(gihttpEntity); + @SuppressWarnings("unchecked") + Map gimap = gson.fromJson(gientityString, Map.class); + assert (gimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD) + && gimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD).equals(iid)); + assert (gimap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gimap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(cid)); + assert (gimap.containsKey(ActionConstants.INPUT_FIELD) && gimap.get(ActionConstants.INPUT_FIELD).equals("input")); + assert (gimap.containsKey(ActionConstants.PROMPT_TEMPLATE_FIELD) + && gimap.get(ActionConstants.PROMPT_TEMPLATE_FIELD).equals("promtp template")); + assert (gimap.containsKey(ActionConstants.AI_RESPONSE_FIELD) && gimap.get(ActionConstants.AI_RESPONSE_FIELD).equals("response")); + assert (gimap.containsKey(ActionConstants.RESPONSE_ORIGIN_FIELD) + && gimap.get(ActionConstants.RESPONSE_ORIGIN_FIELD).equals("origin")); + assert (gimap.containsKey(ActionConstants.ADDITIONAL_INFO_FIELD) + && gimap.get(ActionConstants.ADDITIONAL_INFO_FIELD).equals("some metadata")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java new file mode 100644 index 0000000000..9d0cc6515b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetInteractionActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + assert (action.getName().equals("conversational_get_interaction")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionRequest.class); + verify(client, times(1)).execute(eq(GetInteractionAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + assert (argCaptor.getValue().getInteractionId().equals("iid")); + } +}