From f8330ee6a965962e017ebbae1b570fb479137854 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 10 Oct 2023 14:26:40 -0700 Subject: [PATCH] Add singular get transport layer Signed-off-by: HenryL27 --- .../conversation/GetConversationAction.java | 34 ++++ .../conversation/GetConversationRequest.java | 77 +++++++++ .../conversation/GetConversationResponse.java | 60 +++++++ .../GetConversationTransportAction.java | 100 +++++++++++ .../conversation/GetInteractionAction.java | 34 ++++ .../conversation/GetInteractionRequest.java | 85 ++++++++++ .../conversation/GetInteractionResponse.java | 61 +++++++ .../GetInteractionTransportAction.java | 95 +++++++++++ .../conversation/ConversationActionTests.java | 1 + .../GetConversationRequestTests.java | 68 ++++++++ .../GetConversationResponseTests.java | 62 +++++++ .../GetConversationTransportActionTests.java | 150 +++++++++++++++++ .../GetInteractionRequestTests.java | 85 ++++++++++ .../GetInteractionResponseTests.java | 64 +++++++ .../GetInteractionTransportActionTests.java | 158 ++++++++++++++++++ .../conversation/InteractionActionTests.java | 1 + 16 files changed, 1135 insertions(+) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java new file mode 100644 index 0000000000..7839915201 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for retrieving a top-level conversation object by id + */ +public class GetConversationAction extends ActionType { + /** Instance of this */ + public static final GetConversationAction INSTANCE = new GetConversationAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get"; + + private GetConversationAction() { + super(NAME, GetConversationResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java new file mode 100644 index 0000000000..c5a6f6dd0e --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationRequest extends ActionRequest { + @Getter + private String conversationId; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if something goes wrong reading from stream + */ + public GetConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.conversationId == null) { + exception = addValidationError("GetConversation Request must have a conversation id", exception); + } + return exception; + } + + /** + * Creates a GetConversationRequest from a rest request + * @param request Rest Request representing a GetConversationRequest + * @return the new GetConversationRequest + * @throws IOException if something goes wrong in translation + */ + public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + return new GetConversationRequest(conversationId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java new file mode 100644 index 0000000000..b757723e09 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationResponse extends ActionResponse implements ToXContentObject { + + @Getter + private ConversationMeta conversation; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if soething goes wrong in reading + */ + public GetConversationResponse(StreamInput in) throws IOException { + super(in); + this.conversation = ConversationMeta.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.conversation.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.conversation.toXContent(builder, params); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java new file mode 100644 index 0000000000..0f1c70ad51 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java @@ -0,0 +1,100 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * Transport Action for GetConversation + */ +@Log4j2 +public class GetConversationTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetConversationRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + String conversationId = request.getConversationId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener + .runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(conversationMeta -> { + internalListener.onResponse(new GetConversationResponse(conversationMeta)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getConversation(conversationId, al); + } catch (Exception e) { + log.error("Failed to get Conversation " + conversationId, e); + actionListener.onFailure(e); + } + + } + + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java new file mode 100644 index 0000000000..adaffcb9dc --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for Get Interaction (singular) + */ +public class GetInteractionAction extends ActionType { + /** Instance of this */ + public static final GetInteractionAction INSTANCE = new GetInteractionAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + + private GetInteractionAction() { + super(NAME, GetInteractionResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java new file mode 100644 index 0000000000..6808857c40 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request for GetInteraction + */ +@AllArgsConstructor +public class GetInteractionRequest extends ActionRequest { + @Getter + private String conversationId; + @Getter + private String interactionId; + + /** + * Stream Constructor + * @param in input stream to read this request from + * @throws IOException if somthing goes wrong reading + */ + public GetInteractionRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.interactionId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + out.writeString(this.interactionId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (conversationId == null) { + exception = addValidationError("Get Interaction Request must have a conversation id", exception); + } + if (interactionId == null) { + exception = addValidationError("Get Interaction Request must have an interaction id", exception); + } + return exception; + } + + /** + * Creates a GetInteractionRequest from a Rest Request + * @param request Rest Request representing a GetInteractionRequest + * @return new GetInteractionRequest built from the rest request + * @throws IOException if something goes wrong reading from the rest request + */ + public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + return new GetInteractionRequest(conversationId, interactionId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java new file mode 100644 index 0000000000..7d3a1f3c73 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse for Get Interaction (sg) + */ +@AllArgsConstructor +public class GetInteractionResponse extends ActionResponse implements ToXContentObject { + + @Getter + private Interaction interaction; + + /** + * Stream Constructor + * @param in Stream Input to read this response from + * @throws IOException if something goes wrong reading from stream + */ + public GetInteractionResponse(StreamInput in) throws IOException { + super(in); + this.interaction = Interaction.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + interaction.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.interaction.toXContent(builder, params); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java new file mode 100644 index 0000000000..16205ec8b9 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetInteractionTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetInteractionAction.NAME, transportService, actionFilters, GetInteractionRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetInteractionRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String conversationId = request.getConversationId(); + String interactionId = request.getInteractionId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(interaction -> { + internalListener.onResponse(new GetInteractionResponse(interaction)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getInteraction(conversationId, interactionId, al); + } catch (Exception e) { + log.error("Failed to get interaction " + interactionId + " in conversation " + conversationId, e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java index 65cbb7dfea..541ff5ed2e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java @@ -25,5 +25,6 @@ public void testActions() { assert (DeleteConversationAction.INSTANCE instanceof DeleteConversationAction); assert (GetConversationsAction.INSTANCE instanceof GetConversationsAction); assert (SearchConversationsAction.INSTANCE instanceof SearchConversationsAction); + assert (GetConversationAction.INSTANCE instanceof GetConversationAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java new file mode 100644 index 0000000000..cb8b67b44b --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetConversationRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetConversationRequest request = new GetConversationRequest("Test-id"); + assert (request.validate() == null); + assert (request.getConversationId().equals("Test-id")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationRequest newRequest = new GetConversationRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("Test-id")); + } + + public void testNullConvoId_ThenFail() { + String id = null; + GetConversationRequest request = new GetConversationRequest(id); + ActionRequestValidationException exc = request.validate(); + assert (exc != null); + assert (exc.validationErrors().size() == 1); + assert (exc.validationErrors().get(0).equals("GetConversation Request must have a conversation id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "testcid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetConversationRequest request = GetConversationRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java new file mode 100644 index 0000000000..4b8f3a8fed --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.test.OpenSearchTestCase; + +public class GetConversationResponseTests extends OpenSearchTestCase { + + public void testGetConversationResponseStreaming() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + assert (response.getConversation().equals(convo)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationResponse newResponse = new GetConversationResponse(in); + assert (newResponse.getConversation().equals(convo)); + } + + public void testToXContent() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"name\":\"name\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java new file mode 100644 index 0000000000..3afcc1dd21 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetConversationTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetConversationRequest request; + GetConversationTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetConversationRequest("test-cid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetConversation() { + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), "name", null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(result); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversation().getId().equals("test-cid")); + } + + public void testGetConversationFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("CMHandler Failure")); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Throws")).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Throws")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java new file mode 100644 index 0000000000..678004ae09 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetInteractionRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetInteractionRequest request = new GetInteractionRequest("cid", "iid"); + assert (request.validate() == null); + assert (request.getConversationId().equals("cid")); + assert (request.getInteractionId().equals("iid")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionRequest newRequest = new GetInteractionRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("cid")); + assert (newRequest.getInteractionId().equals("iid")); + } + + public void testMalformedRequest_ThenInvalid() { + GetInteractionRequest bad1 = new GetInteractionRequest(null, "iid"); + GetInteractionRequest bad2 = new GetInteractionRequest("cid", null); + GetInteractionRequest bad3 = new GetInteractionRequest(null, null); + ActionRequestValidationException exc1 = bad1.validate(); + ActionRequestValidationException exc2 = bad2.validate(); + ActionRequestValidationException exc3 = bad3.validate(); + + assert (exc1 != null); + assert (exc1.validationErrors().size() == 1); + assert (exc1.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + + assert (exc2 != null); + assert (exc2.validationErrors().size() == 1); + assert (exc2.validationErrors().get(0).equals("Get Interaction Request must have an interaction id")); + + assert (exc3 != null); + assert (exc3.validationErrors().size() == 2); + assert (exc3.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + assert (exc3.validationErrors().get(1).equals("Get Interaction Request must have an interaction id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map + .of(ActionConstants.CONVERSATION_ID_FIELD, "testcid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "testiid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetInteractionRequest request = GetInteractionRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + assert (request.getInteractionId().equals("testiid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java new file mode 100644 index 0000000000..b7cbc1c471 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetInteractionResponseTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + assert (response.getInteraction().equals(interaction)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionResponse newResponse = new GetInteractionResponse(in); + assert (newResponse.getInteraction().equals(interaction)); + } + + public void testToXContent() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\"" + + interaction.getCreateTime() + + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":\"extra\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java new file mode 100644 index 0000000000..6ca8197b54 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetInteractionTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetInteractionRequest request; + GetInteractionTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetInteractionRequest("cid", "iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetInteraction() { + Interaction testInteraction = new Interaction( + "iid", + Instant.now(), + "cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(testInteraction); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getInteraction().getId().equals("iid")); + } + + public void testGetInteractionFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("Storage layer failure")); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Storage layer failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Failure")).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java index 89a6fae6a3..187ae4bdf7 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java @@ -24,5 +24,6 @@ public void testActions() { assert (CreateInteractionAction.INSTANCE instanceof CreateInteractionAction); assert (GetInteractionsAction.INSTANCE instanceof GetInteractionsAction); assert (SearchInteractionsAction.INSTANCE instanceof SearchInteractionsAction); + assert (GetInteractionAction.INSTANCE instanceof GetInteractionAction); } }