Skip to content

Commit

Permalink
Add singular get transport layer
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Oct 11, 2023
1 parent f148f52 commit f8330ee
Show file tree
Hide file tree
Showing 16 changed files with 1,135 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action for retrieving a top-level conversation object by id
*/
public class GetConversationAction extends ActionType<GetConversationResponse> {
/** Instance of this */
public static final GetConversationAction INSTANCE = new GetConversationAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get";

private GetConversationAction() {
super(NAME, GetConversationResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* Action Request object for GetConversation (singular)
*/
@AllArgsConstructor
public class GetConversationRequest extends ActionRequest {
@Getter
private String conversationId;

/**
* Stream Constructor
* @param in input stream to read this from
* @throws IOException if something goes wrong reading from stream
*/
public GetConversationRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.conversationId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (this.conversationId == null) {
exception = addValidationError("GetConversation Request must have a conversation id", exception);
}
return exception;
}

/**
* Creates a GetConversationRequest from a rest request
* @param request Rest Request representing a GetConversationRequest
* @return the new GetConversationRequest
* @throws IOException if something goes wrong in translation
*/
public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException {
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD);
return new GetConversationRequest(conversationId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ConversationMeta;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* ActionResponse object for GetConversation (singular)
*/
@AllArgsConstructor
public class GetConversationResponse extends ActionResponse implements ToXContentObject {

@Getter
private ConversationMeta conversation;

/**
* Stream Constructor
* @param in input stream to read this from
* @throws IOException if soething goes wrong in reading
*/
public GetConversationResponse(StreamInput in) throws IOException {
super(in);
this.conversation = ConversationMeta.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
this.conversation.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return this.conversation.toXContent(builder, params);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

/**
* Transport Action for GetConversation
*/
@Log4j2
public class GetConversationTransportAction extends HandledTransportAction<GetConversationRequest, GetConversationResponse> {
private Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetConversationRequest request, ActionListener<GetConversationResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
} else {
String conversationId = request.getConversationId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<GetConversationResponse> internalListener = ActionListener
.runBefore(actionListener, () -> context.restore());
ActionListener<ConversationMeta> al = ActionListener.wrap(conversationMeta -> {
internalListener.onResponse(new GetConversationResponse(conversationMeta));
}, e -> { internalListener.onFailure(e); });
cmHandler.getConversation(conversationId, al);
} catch (Exception e) {
log.error("Failed to get Conversation " + conversationId, e);
actionListener.onFailure(e);
}

}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action for Get Interaction (singular)
*/
public class GetInteractionAction extends ActionType<GetInteractionResponse> {
/** Instance of this */
public static final GetInteractionAction INSTANCE = new GetInteractionAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get";

private GetInteractionAction() {
super(NAME, GetInteractionResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* Action Request for GetInteraction
*/
@AllArgsConstructor
public class GetInteractionRequest extends ActionRequest {
@Getter
private String conversationId;
@Getter
private String interactionId;

/**
* Stream Constructor
* @param in input stream to read this request from
* @throws IOException if somthing goes wrong reading
*/
public GetInteractionRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
this.interactionId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.conversationId);
out.writeString(this.interactionId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (conversationId == null) {
exception = addValidationError("Get Interaction Request must have a conversation id", exception);
}
if (interactionId == null) {
exception = addValidationError("Get Interaction Request must have an interaction id", exception);
}
return exception;
}

/**
* Creates a GetInteractionRequest from a Rest Request
* @param request Rest Request representing a GetInteractionRequest
* @return new GetInteractionRequest built from the rest request
* @throws IOException if something goes wrong reading from the rest request
*/
public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException {
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD);
return new GetInteractionRequest(conversationId, interactionId);
}
}
Loading

0 comments on commit f8330ee

Please sign in to comment.