From 85a29017a5df7dc78e65799662b0afba48216d74 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 31 Aug 2023 17:34:33 -0700 Subject: [PATCH] fix feature flag with updateConsumer Signed-off-by: HenryL27 --- .../ConversationalIndexConstants.java | 7 +++++-- .../CreateConversationTransportAction.java | 16 ++++++++++---- .../CreateInteractionTransportAction.java | 16 ++++++++++---- .../DeleteConversationTransportAction.java | 16 ++++++++++---- .../GetConversationsTransportAction.java | 16 ++++++++++---- .../GetInteractionsTransportAction.java | 16 ++++++++++---- ...reateConversationTransportActionTests.java | 12 +++++++++-- ...CreateInteractionTransportActionTests.java | 13 ++++++++++-- ...eleteConversationTransportActionTests.java | 12 +++++++++-- .../GetConversationsTransportActionTests.java | 12 +++++++++-- .../GetInteractionsTransportActionTests.java | 12 +++++++++-- .../ml/settings/MLCommonsSettings.java | 3 +-- .../RestMemoryCreateConversationActionIT.java | 21 +++++++++++++++++++ .../RestMemoryCreateInteractionActionIT.java | 21 +++++++++++++++++++ .../RestMemoryDeleteConversationActionIT.java | 21 +++++++++++++++++++ .../RestMemoryGetConversationsActionIT.java | 21 +++++++++++++++++++ .../RestMemoryGetInteractionsActionIT.java | 21 +++++++++++++++++++ 17 files changed, 222 insertions(+), 34 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index 051992636f..c8e652265b 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.common.conversation; +import org.opensearch.common.settings.Setting; + /** * Class containing a bunch of constant defining how the conversational indices are formatted */ @@ -97,6 +99,7 @@ public class ConversationalIndexConstants { + " }\n" + "}"; - /** Name of the feature flag for conversational memory */ - public final static String MEMORY_FEATURE_FLAG_NAME = "plugins.ml_commons.memory_feature_enabled"; + /** Feature Flag setting for conversational memory */ + public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting + .boolSetting("plugins.ml_commons.memory_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } \ No newline at end of file diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java index ce92e455da..3d3ac5f656 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -41,7 +42,8 @@ public class CreateConversationTransportAction extends HandledTransportAction setting = (Setting) clusterService + .getClusterSettings() + .get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()); + this.featureIsEnabled = setting.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, CreateConversationRequest request, ActionListener actionListener) { - if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " - + ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() ) ); return; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java index 18057cddf8..d145d1f770 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -41,7 +42,8 @@ public class CreateInteractionTransportAction extends HandledTransportAction setting = (Setting) clusterService + .getClusterSettings() + .get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()); + this.featureIsEnabled = setting.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, CreateInteractionRequest request, ActionListener actionListener) { - if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " - + ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() ) ); return; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java index 4ba727c222..b9fb725fe6 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -41,7 +42,8 @@ public class DeleteConversationTransportAction extends HandledTransportAction setting = (Setting) clusterService + .getClusterSettings() + .get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()); + this.featureIsEnabled = setting.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it); } @Override public void doExecute(Task task, DeleteConversationRequest request, ActionListener listener) { - if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) { + if (!featureIsEnabled) { listener .onFailure( new OpenSearchException( "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " - + ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() ) ); return; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java index 24ea9465a6..efd6809e5c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java @@ -25,6 +25,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -44,7 +45,8 @@ public class GetConversationsTransportAction extends HandledTransportAction setting = (Setting) clusterService + .getClusterSettings() + .get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()); + this.featureIsEnabled = setting.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it); } @Override public void doExecute(Task task, GetConversationsRequest request, ActionListener actionListener) { - if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " - + ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() ) ); return; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java index af6cf31289..50946af78e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java @@ -25,6 +25,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -44,7 +45,8 @@ public class GetInteractionsTransportAction extends HandledTransportAction setting = (Setting) clusterService + .getClusterSettings() + .get(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()); + this.featureIsEnabled = setting.get(clusterService.getSettings()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> featureIsEnabled = it); } @Override public void doExecute(Task task, GetInteractionsRequest request, ActionListener actionListener) { - if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " - + ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() ) ); return; diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java index 42a2d15016..313071dc45 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -33,6 +34,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -90,13 +92,16 @@ public void setup() throws IOException { this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); this.request = new CreateConversationRequest("test"); - this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build(); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testCreateConversation() { @@ -148,6 +153,9 @@ public void testDoExecuteFails_thenFail() { public void testFeatureDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java index 423bd7e91a..8321a0b65e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -33,6 +34,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -90,13 +92,17 @@ public void setup() throws IOException { this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata"); - this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build(); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } public void testCreateInteraction() { @@ -138,6 +144,9 @@ public void testDoExecuteFails_thenFail() { public void testFeatureDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java index 35af9d1c1d..984b9a2fbf 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -33,6 +34,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -90,13 +92,16 @@ public void setup() throws IOException { this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); this.request = new DeleteConversationRequest("test"); - this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build(); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testDeleteConversation() { @@ -134,6 +139,9 @@ public void testdoExecuteFails_thenFail() { public void testFeatureDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java index bdded0e3f8..41c99bdc74 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.time.Instant; import java.util.List; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -36,6 +37,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -94,13 +96,16 @@ public void setup() throws IOException { this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); this.request = new GetConversationsRequest(); - this.action = spy(new GetConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build(); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testGetConversations() { @@ -189,6 +194,9 @@ public void testdoExecuteFails_thenFail() { public void testFeatureDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java index 8ee88b2a54..a7a245b680 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.time.Instant; import java.util.List; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -36,6 +37,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -94,13 +96,16 @@ public void setup() throws IOException { this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); this.request = new GetInteractionsRequest("test-cid"); - this.action = spy(new GetInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); - Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build(); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testGetInteractions_noMorePages() { @@ -180,6 +185,9 @@ public void testDoExecuteFails_thenFail() { public void testFeatureDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 88da0cb748..dc9c209535 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -129,6 +129,5 @@ private MLCommonsSettings() {} Setting.Property.Dynamic ); - public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting - .boolSetting(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED; } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java index b527ef7a32..0d909805b8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java @@ -21,13 +21,34 @@ import java.util.Map; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.utils.TestHelper; +import com.google.common.collect.ImmutableList; + public class RestMemoryCreateConversationActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + super.setupSettings(); + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + public void testCreateConversation() throws IOException { Response response = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (response != null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java index 32f2b484dd..dbef57d81e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java @@ -21,13 +21,34 @@ import java.util.Map; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.utils.TestHelper; +import com.google.common.collect.ImmutableList; + public class RestMemoryCreateInteractionActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + super.setupSettings(); + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + public void testCreateInteraction() throws IOException { Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse != null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java index f33e163b67..7633d30ac9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java @@ -22,13 +22,34 @@ import java.util.Map; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.utils.TestHelper; +import com.google.common.collect.ImmutableList; + public class RestMemoryDeleteConversationActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + super.setupSettings(); + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + public void testDeleteConversation_ThatExists() throws IOException { Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse != null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java index 3ce0eb8ff0..5273f74472 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -22,13 +22,34 @@ import java.util.Map; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.utils.TestHelper; +import com.google.common.collect.ImmutableList; + public class RestMemoryGetConversationsActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + super.setupSettings(); + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + public void testNoConversations_EmptyList() throws IOException { Response response = TestHelper.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATIONS_REST_PATH, null, "", null); assert (response != null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java index a4ee6118d3..ac269b5e7c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java @@ -22,13 +22,34 @@ import java.util.Map; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.utils.TestHelper; +import com.google.common.collect.ImmutableList; + public class RestMemoryGetInteractionsActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + super.setupSettings(); + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + public void testGetInteractions_NoConversation() throws IOException { Response response = TestHelper .makeRequest(