Skip to content

Commit

Permalink
[Backport 2.x] Added an optional workflow_step param to the get workf…
Browse files Browse the repository at this point in the history
…low steps API (#542)

Added an optional workflow_step param to the get workflow steps API (#538)

* Added optional step param to get the workflow steps API



* Fixed api response



* Added tests



* Added CHANGELOG



* Addressed PR comments



* Added another test



* Logged exception message



---------


(cherry picked from commit 2c6dfd5)

Signed-off-by: Owais Kazi <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 2eb8769 commit 2540db6
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)

### Enhancements
- Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525))
- Added an optional workflow_step param to the get workflow steps API ([#538](https://github.com/opensearch-project/flow-framework/pull/538))

### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ private CommonValue() {}
public static final String VALIDATION = "validation";
/** The field name for provision workflow within a use case template*/
public static final String PROVISION_WORKFLOW = "provision";
/** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */
public static final String WORKFLOW_STEP = "workflow_step";

/*
* Constants associated with plugin configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -21,14 +19,18 @@
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.GetWorkflowStepAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;

Expand Down Expand Up @@ -60,7 +62,7 @@ public List<Route> routes() {
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
try {
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
Expand All @@ -69,13 +71,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
);
}

ActionRequest request = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}
};
return channel -> client.execute(GetWorkflowStepAction.INSTANCE, request, ActionListener.wrap(response -> {
Map<String, String> params = request.hasParam(WORKFLOW_STEP)
? Map.of(WORKFLOW_STEP, request.param(WORKFLOW_STEP))
: Collections.emptyMap();

WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params);
return channel -> client.execute(GetWorkflowStepAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,27 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.WorkflowValidator;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP;

/**
* Transport action to retrieve a workflow step json
*/
public class GetWorkflowStepTransportAction extends HandledTransportAction<ActionRequest, GetWorkflowStepResponse> {
public class GetWorkflowStepTransportAction extends HandledTransportAction<WorkflowRequest, GetWorkflowStepResponse> {

private final Logger logger = LogManager.getLogger(GetWorkflowStepTransportAction.class);
private final WorkflowStepFactory workflowStepFactory;
Expand All @@ -47,11 +53,23 @@ public GetWorkflowStepTransportAction(
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<GetWorkflowStepResponse> listener) {
protected void doExecute(Task task, WorkflowRequest request, ActionListener<GetWorkflowStepResponse> listener) {
try {
WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidator();
List<String> steps = request.getParams().size() > 0
? Arrays.asList(Strings.splitStringByCommaToArray(request.getParams().get(WORKFLOW_STEP)))
: Collections.emptyList();
WorkflowValidator workflowValidator;
if (steps.isEmpty()) {
workflowValidator = this.workflowStepFactory.getWorkflowValidator();
} else {
workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps);
}
listener.onResponse(new GetWorkflowStepResponse(workflowValidator));
} catch (Exception e) {
if (e instanceof FlowFrameworkException) {
logger.error(e.getMessage());
listener.onFailure(e);
}
logger.error("Failed to retrieve workflow step json.", e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param template the use case template which describes the workflow
* @param params The parameters from the REST path
*/
public WorkflowRequest(String workflowId, @Nullable Template template, Map<String, String> params) {
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map<String, String> params) {
this(workflowId, template, new String[] { "all" }, true, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
Expand Down Expand Up @@ -364,6 +366,30 @@ public WorkflowValidator getWorkflowValidator() {
return new WorkflowValidator(workflowStepValidators);
}

/**
* Get the object of WorkflowValidator consisting of passed workflow steps
* @param steps workflow steps
* @return WorkflowValidator
*/
public WorkflowValidator getWorkflowValidatorByStep(List<String> steps) {
Map<String, WorkflowStepValidator> workflowStepValidators = new HashMap<>();
Set<String> invalidSteps = new HashSet<>(steps);

for (WorkflowSteps mapping : WorkflowSteps.values()) {
String step = mapping.getWorkflowStepName();
if (steps.contains(step)) {
workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator());
invalidSteps.remove(step);
}
}

if (!invalidSteps.isEmpty()) {
throw new FlowFrameworkException("Invalid step name: " + invalidSteps, RestStatus.BAD_REQUEST);
}

return new WorkflowValidator(workflowStepValidators);
}

/**
* Create a new instance of a {@link WorkflowStep}.
* @param type The type of instance to create
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,32 @@
package org.opensearch.flowframework.rest;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.WorkflowValidator;
import org.opensearch.flowframework.transport.GetWorkflowStepResponse;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestChannel;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -31,12 +43,22 @@ public class RestGetWorkflowStepActionTests extends OpenSearchTestCase {
private String getPath;
private FlowFrameworkSettings flowFrameworkFeatureEnabledSetting;
private NodeClient nodeClient;
private WorkflowStepFactory workflowStepFactory;
private FlowFrameworkSettings flowFrameworkSettings;

@Override
public void setUp() throws Exception {
super.setUp();

flowFrameworkSettings = mock(FlowFrameworkSettings.class);
when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true);

this.getPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_steps");
ThreadPool threadPool = mock(ThreadPool.class);
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);

this.workflowStepFactory = new WorkflowStepFactory(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkSettings.class);
when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true);
this.restGetWorkflowStepAction = new RestGetWorkflowStepAction(flowFrameworkFeatureEnabledSetting);
Expand Down Expand Up @@ -68,6 +90,46 @@ public void testInvalidRequestWithContent() {
assertEquals("request [GET /_plugins/_flow_framework/workflow/_steps] does not support having a body", ex.getMessage());
}

public void testWorkflowSteps() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET)
.withPath(this.getPath + "?workflow_step=create_connector")
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
List<String> steps = new ArrayList<>();
steps.add("create_connector");
doAnswer(invocation -> {
ActionListener<GetWorkflowStepResponse> actionListener = invocation.getArgument(2);
WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps);
actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator));
return null;
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
restGetWorkflowStepAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.OK, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("create_connector"));
}

public void testFailedWorkflowSteps() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET)
.withPath(this.getPath + "?workflow_step=xyz")
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
List<String> steps = new ArrayList<>();
steps.add("xyz");
doAnswer(invocation -> {
ActionListener<GetWorkflowStepResponse> actionListener = invocation.getArgument(2);
WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidatorByStep(steps);
actionListener.onResponse(new GetWorkflowStepResponse(workflowValidator));
return null;
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
FlowFrameworkException exception = expectThrows(
FlowFrameworkException.class,
() -> restGetWorkflowStepAction.handleRequest(request, channel, nodeClient)
);
assertEquals("Invalid step name: [xyz]", exception.getMessage());
}

public void testFeatureFlagNotEnabled() throws Exception {
when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false);
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.GET)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.mockito.ArgumentCaptor;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STEP;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

@SuppressWarnings("unchecked")
public class GetWorkflowStepTransportActionTests extends OpenSearchTestCase {

private GetWorkflowStepTransportAction getWorkflowStepTransportAction;
Expand All @@ -45,6 +49,18 @@ public void testGetWorkflowStepAction() throws IOException {

ArgumentCaptor<GetWorkflowStepResponse> stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class);
verify(listener, times(1)).onResponse(stepCaptor.capture());
}

public void testGetWorkflowStepValidator() throws IOException {
Map<String, String> params = new HashMap<>();
params.put(WORKFLOW_STEP, "create_connector, delete_model");

WorkflowRequest workflowRequest = new WorkflowRequest(null, null, params);
ActionListener<GetWorkflowStepResponse> listener = mock(ActionListener.class);
getWorkflowStepTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<GetWorkflowStepResponse> stepCaptor = ArgumentCaptor.forClass(GetWorkflowStepResponse.class);
verify(listener, times(1)).onResponse(stepCaptor.capture());
assertEquals(GetWorkflowStepResponse.class, stepCaptor.getValue().getClass());

}
}

0 comments on commit 2540db6

Please sign in to comment.