Skip to content

Commit

Permalink
fix feature flag with updateConsumer
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Sep 1, 2023
1 parent 6927c45 commit 85a2901
Show file tree
Hide file tree
Showing 17 changed files with 222 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.common.conversation;

import org.opensearch.common.settings.Setting;

/**
* Class containing a bunch of constant defining how the conversational indices are formatted
*/
Expand Down Expand Up @@ -97,6 +99,7 @@ public class ConversationalIndexConstants {
+ " }\n"
+ "}";

/** Name of the feature flag for conversational memory */
public final static String MEMORY_FEATURE_FLAG_NAME = "plugins.ml_commons.memory_feature_enabled";
/** Feature Flag setting for conversational memory */
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.memory_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
Expand All @@ -41,14 +42,16 @@ public class CreateConversationTransportAction extends HandledTransportAction<Cr

private ConversationalMemoryHandler cmHandler;
private Client client;
private ClusterService clusterService;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public CreateConversationTransportAction(
Expand All @@ -61,17 +64,22 @@ public CreateConversationTransportAction(
super(CreateConversationAction.NAME, transportService, actionFilters, CreateConversationRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.clusterService = clusterService;
@SuppressWarnings("unchecked")
Setting<Boolean> setting = (Setting<Boolean>) clusterService
.getClusterSettings()
.get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey());
this.featureIsEnabled = setting.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, CreateConversationRequest request, ActionListener<CreateConversationResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
Expand All @@ -41,14 +42,16 @@ public class CreateInteractionTransportAction extends HandledTransportAction<Cre

private ConversationalMemoryHandler cmHandler;
private Client client;
private ClusterService clusterService;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for doing intra-cluster communication
* @param actionFilters for filtering actions
* @param cmHandler handler for conversational memory
* @param client client for general opensearch ops
* @param clusterService for some cluster ops
*/
@Inject
public CreateInteractionTransportAction(
Expand All @@ -61,17 +64,22 @@ public CreateInteractionTransportAction(
super(CreateInteractionAction.NAME, transportService, actionFilters, CreateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
@SuppressWarnings("unchecked")
Setting<Boolean> setting = (Setting<Boolean>) clusterService
.getClusterSettings()
.get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey());
this.featureIsEnabled = setting.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, CreateInteractionRequest request, ActionListener<CreateInteractionResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
Expand All @@ -41,14 +42,16 @@ public class DeleteConversationTransportAction extends HandledTransportAction<De

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public DeleteConversationTransportAction(
Expand All @@ -61,17 +64,22 @@ public DeleteConversationTransportAction(
super(DeleteConversationAction.NAME, transportService, actionFilters, DeleteConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
@SuppressWarnings("unchecked")
Setting<Boolean> setting = (Setting<Boolean>) clusterService
.getClusterSettings()
.get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey());
this.featureIsEnabled = setting.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
Expand All @@ -44,14 +45,16 @@ public class GetConversationsTransportAction extends HandledTransportAction<GetC

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetConversationsTransportAction(
Expand All @@ -64,17 +67,22 @@ public GetConversationsTransportAction(
super(GetConversationsAction.NAME, transportService, actionFilters, GetConversationsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
@SuppressWarnings("unchecked")
Setting<Boolean> setting = (Setting<Boolean>) clusterService
.getClusterSettings()
.get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey());
this.featureIsEnabled = setting.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetConversationsRequest request, ActionListener<GetConversationsResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
Expand All @@ -44,14 +45,16 @@ public class GetInteractionsTransportAction extends HandledTransportAction<GetIn

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetInteractionsTransportAction(
Expand All @@ -64,17 +67,22 @@ public GetInteractionsTransportAction(
super(GetInteractionsAction.NAME, transportService, actionFilters, GetInteractionsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
@SuppressWarnings("unchecked")
Setting<Boolean> setting = (Setting<Boolean>) clusterService
.getClusterSettings()
.get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey());
this.featureIsEnabled = setting.get(clusterService.getSettings());
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetInteractionsRequest request, ActionListener<GetInteractionsResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Set;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -33,6 +34,7 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -90,13 +92,16 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new CreateConversationRequest("test");
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
when(this.clusterService.getClusterSettings())
.thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));

this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
}

public void testCreateConversation() {
Expand Down Expand Up @@ -148,6 +153,9 @@ public void testDoExecuteFails_thenFail() {

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Set;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -33,6 +34,7 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -90,13 +92,17 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata");
this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
when(this.clusterService.getClusterSettings())
.thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));

this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

}

public void testCreateInteraction() {
Expand Down Expand Up @@ -138,6 +144,9 @@ public void testDoExecuteFails_thenFail() {

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));
this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
Expand Down
Loading

0 comments on commit 85a2901

Please sign in to comment.