forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: HenryL27 <[email protected]>
- Loading branch information
Showing
16 changed files
with
1,135 additions
and
0 deletions.
There are no files selected for viewing
34 changes: 34 additions & 0 deletions
34
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
/** | ||
* Action for retrieving a top-level conversation object by id | ||
*/ | ||
public class GetConversationAction extends ActionType<GetConversationResponse> { | ||
/** Instance of this */ | ||
public static final GetConversationAction INSTANCE = new GetConversationAction(); | ||
/** Name of this action */ | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get"; | ||
|
||
private GetConversationAction() { | ||
super(NAME, GetConversationResponse::new); | ||
} | ||
} |
77 changes: 77 additions & 0 deletions
77
...ry/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Action Request object for GetConversation (singular) | ||
*/ | ||
@AllArgsConstructor | ||
public class GetConversationRequest extends ActionRequest { | ||
@Getter | ||
private String conversationId; | ||
|
||
/** | ||
* Stream Constructor | ||
* @param in input stream to read this from | ||
* @throws IOException if something goes wrong reading from stream | ||
*/ | ||
public GetConversationRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversationId = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.conversationId); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (this.conversationId == null) { | ||
exception = addValidationError("GetConversation Request must have a conversation id", exception); | ||
} | ||
return exception; | ||
} | ||
|
||
/** | ||
* Creates a GetConversationRequest from a rest request | ||
* @param request Rest Request representing a GetConversationRequest | ||
* @return the new GetConversationRequest | ||
* @throws IOException if something goes wrong in translation | ||
*/ | ||
public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { | ||
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); | ||
return new GetConversationRequest(conversationId); | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
...y/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.conversation.ConversationMeta; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* ActionResponse object for GetConversation (singular) | ||
*/ | ||
@AllArgsConstructor | ||
public class GetConversationResponse extends ActionResponse implements ToXContentObject { | ||
|
||
@Getter | ||
private ConversationMeta conversation; | ||
|
||
/** | ||
* Stream Constructor | ||
* @param in input stream to read this from | ||
* @throws IOException if soething goes wrong in reading | ||
*/ | ||
public GetConversationResponse(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversation = ConversationMeta.fromStream(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
this.conversation.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
return this.conversation.toXContent(builder, params); | ||
} | ||
} |
100 changes: 100 additions & 0 deletions
100
...ain/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.ConversationMeta; | ||
import org.opensearch.ml.common.conversation.ConversationalIndexConstants; | ||
import org.opensearch.ml.memory.ConversationalMemoryHandler; | ||
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Transport Action for GetConversation | ||
*/ | ||
@Log4j2 | ||
public class GetConversationTransportAction extends HandledTransportAction<GetConversationRequest, GetConversationResponse> { | ||
private Client client; | ||
private ConversationalMemoryHandler cmHandler; | ||
|
||
private volatile boolean featureIsEnabled; | ||
|
||
/** | ||
* Constructor | ||
* @param transportService for inter-node communications | ||
* @param actionFilters for filtering actions | ||
* @param cmHandler Handler for conversational memory operations | ||
* @param client OS Client for dealing with OS | ||
* @param clusterService for some cluster ops | ||
*/ | ||
@Inject | ||
public GetConversationTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
OpenSearchConversationalMemoryHandler cmHandler, | ||
Client client, | ||
ClusterService clusterService | ||
) { | ||
super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new); | ||
this.client = client; | ||
this.cmHandler = cmHandler; | ||
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); | ||
clusterService | ||
.getClusterSettings() | ||
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); | ||
} | ||
|
||
@Override | ||
public void doExecute(Task task, GetConversationRequest request, ActionListener<GetConversationResponse> actionListener) { | ||
if (!featureIsEnabled) { | ||
actionListener | ||
.onFailure( | ||
new OpenSearchException( | ||
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting " | ||
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() | ||
) | ||
); | ||
return; | ||
} else { | ||
String conversationId = request.getConversationId(); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { | ||
ActionListener<GetConversationResponse> internalListener = ActionListener | ||
.runBefore(actionListener, () -> context.restore()); | ||
ActionListener<ConversationMeta> al = ActionListener.wrap(conversationMeta -> { | ||
internalListener.onResponse(new GetConversationResponse(conversationMeta)); | ||
}, e -> { internalListener.onFailure(e); }); | ||
cmHandler.getConversation(conversationId, al); | ||
} catch (Exception e) { | ||
log.error("Failed to get Conversation " + conversationId, e); | ||
actionListener.onFailure(e); | ||
} | ||
|
||
} | ||
|
||
} | ||
} |
34 changes: 34 additions & 0 deletions
34
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
/** | ||
* Action for Get Interaction (singular) | ||
*/ | ||
public class GetInteractionAction extends ActionType<GetInteractionResponse> { | ||
/** Instance of this */ | ||
public static final GetInteractionAction INSTANCE = new GetInteractionAction(); | ||
/** Name of this action */ | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; | ||
|
||
private GetInteractionAction() { | ||
super(NAME, GetInteractionResponse::new); | ||
} | ||
} |
85 changes: 85 additions & 0 deletions
85
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Action Request for GetInteraction | ||
*/ | ||
@AllArgsConstructor | ||
public class GetInteractionRequest extends ActionRequest { | ||
@Getter | ||
private String conversationId; | ||
@Getter | ||
private String interactionId; | ||
|
||
/** | ||
* Stream Constructor | ||
* @param in input stream to read this request from | ||
* @throws IOException if somthing goes wrong reading | ||
*/ | ||
public GetInteractionRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversationId = in.readString(); | ||
this.interactionId = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.conversationId); | ||
out.writeString(this.interactionId); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (conversationId == null) { | ||
exception = addValidationError("Get Interaction Request must have a conversation id", exception); | ||
} | ||
if (interactionId == null) { | ||
exception = addValidationError("Get Interaction Request must have an interaction id", exception); | ||
} | ||
return exception; | ||
} | ||
|
||
/** | ||
* Creates a GetInteractionRequest from a Rest Request | ||
* @param request Rest Request representing a GetInteractionRequest | ||
* @return new GetInteractionRequest built from the rest request | ||
* @throws IOException if something goes wrong reading from the rest request | ||
*/ | ||
public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException { | ||
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); | ||
String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); | ||
return new GetInteractionRequest(conversationId, interactionId); | ||
} | ||
} |
Oops, something went wrong.