From 2b2080e3a6be459161f3d3dfa4e2b20172133131 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 15 Feb 2024 13:53:57 -0800 Subject: [PATCH 1/9] Include params map in WorkflowRequest when provisioning Signed-off-by: Daniel Widdis --- .../rest/RestCreateWorkflowAction.java | 24 +++++++++- .../rest/RestProvisionWorkflowAction.java | 12 ++++- .../transport/WorkflowRequest.java | 46 +++++++++++++++++-- .../CreateWorkflowTransportActionTests.java | 14 +++--- .../WorkflowRequestResponseTests.java | 43 +++++++++++++++++ 5 files changed, 127 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 490f4f3ff..595ed4932 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -27,8 +27,11 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; @@ -75,6 +78,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); + Map params = Collections.emptyMap(); + final List validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); + // If provisioning, consume all other params and pass to provision transport action + if (provision) { + params = request.params() + .entrySet() + .stream() + .filter(e -> !validCreateParams.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", @@ -84,12 +97,21 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); } + if (!provision && !params.isEmpty()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } try { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); Template template = Template.parse(parser); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 124b6bf49..7c0fdb3d4 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -27,7 +27,11 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.VALIDATION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -68,6 +72,12 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + final List excludeParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); + Map params = request.params() + .entrySet() + .stream() + .filter(e -> !WORKFLOW_ID.equals(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( @@ -85,7 +95,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 341c79742..6268dab2c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -16,6 +16,8 @@ import org.opensearch.flowframework.model.Template; import java.io.IOException; +import java.util.Collections; +import java.util.Map; /** * Transport Request to create, provision, and deprovision a workflow @@ -43,12 +45,27 @@ public class WorkflowRequest extends ActionRequest { private boolean provision; /** - * Instantiates a new WorkflowRequest, set validation to false and set requestTimeout and maxWorkflows to null + * Params map + */ + private Map params; + + /** + * Instantiates a new WorkflowRequest, set validation to all, no provisioning * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap()); + } + + /** + * Instantiates a new WorkflowRequest with params map, set validation to all, provisioning to true + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param params The parameters from the REST path + */ + public WorkflowRequest(String workflowId, @Nullable Template template, Map params) { + this(workflowId, template, new String[] { "all" }, true, params); } /** @@ -57,12 +74,23 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param template the use case template which describes the workflow * @param validation flag to indicate if validation is necessary * @param provision flag to indicate if provision is necessary + * @param params map of REST path params. If provision is false, must be an empty map. */ - public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String[] validation, boolean provision) { + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + String[] validation, + boolean provision, + Map params + ) { this.workflowId = workflowId; this.template = template; this.validation = validation; this.provision = provision; + if (!provision && !params.isEmpty()) { + throw new IllegalArgumentException("Params may only be included when provisioning."); + } + this.params = params; } /** @@ -77,6 +105,7 @@ public WorkflowRequest(StreamInput in) throws IOException { this.template = templateJson == null ? null : Template.parse(templateJson); this.validation = in.readStringArray(); this.provision = in.readBoolean(); + this.params = this.provision ? in.readMap(StreamInput::readString, StreamInput::readString) : Collections.emptyMap(); } /** @@ -113,6 +142,14 @@ public boolean isProvision() { return this.provision; } + /** + * Gets the params map + * @return the params map + */ + public Map getParams() { + return Map.copyOf(this.params); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -120,6 +157,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(template == null ? null : template.toJson()); out.writeStringArray(validation); out.writeBoolean(provision); + if (provision) { + out.writeMap(params, StreamOutput::writeString, StreamOutput::writeString); + } } @Override diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index e8c8ba4f3..bd4615284 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -132,7 +132,7 @@ public void testValidation_withoutProvision_Success() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate, new String[] { "all" }, false); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, validTemplate); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } @@ -192,7 +192,7 @@ public void testValidation_Failed() throws Exception { ActionListener listener = mock(ActionListener.class); // Stub validation failure doThrow(Exception.class).when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, new String[] { "all" }, false); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); verify(listener, times(1)).onFailure(any()); @@ -203,7 +203,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); @@ -240,7 +240,7 @@ public void onFailure(Exception e) { public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -271,7 +271,7 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap()); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -359,7 +359,7 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap()); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -412,7 +412,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); doNothing().when(workflowProcessSorter).validate(any(), any()); - WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true); + WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap()); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index 362c3feae..200312bec 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -61,6 +61,8 @@ public void testNullIdWorkflowRequest() throws IOException { assertNull(nullIdRequest.getWorkflowId()); assertEquals(template, nullIdRequest.getTemplate()); assertNull(nullIdRequest.validate()); + assertFalse(nullIdRequest.isProvision()); + assertTrue(nullIdRequest.getParams().isEmpty()); BytesStreamOutput out = new BytesStreamOutput(); nullIdRequest.writeTo(out); @@ -70,6 +72,9 @@ public void testNullIdWorkflowRequest() throws IOException { assertEquals(nullIdRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); assertEquals(nullIdRequest.getTemplate().toJson(), streamInputRequest.getTemplate().toJson()); + assertNull(nullIdRequest.validate()); + assertFalse(nullIdRequest.isProvision()); + assertTrue(nullIdRequest.getParams().isEmpty()); } public void testNullTemplateWorkflowRequest() throws IOException { @@ -77,6 +82,8 @@ public void testNullTemplateWorkflowRequest() throws IOException { assertNotNull(nullTemplateRequest.getWorkflowId()); assertNull(nullTemplateRequest.getTemplate()); assertNull(nullTemplateRequest.validate()); + assertFalse(nullTemplateRequest.isProvision()); + assertTrue(nullTemplateRequest.getParams().isEmpty()); BytesStreamOutput out = new BytesStreamOutput(); nullTemplateRequest.writeTo(out); @@ -86,6 +93,9 @@ public void testNullTemplateWorkflowRequest() throws IOException { assertEquals(nullTemplateRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); assertEquals(nullTemplateRequest.getTemplate(), streamInputRequest.getTemplate()); + assertNull(nullTemplateRequest.validate()); + assertFalse(nullTemplateRequest.isProvision()); + assertTrue(nullTemplateRequest.getParams().isEmpty()); } public void testWorkflowRequest() throws IOException { @@ -93,6 +103,29 @@ public void testWorkflowRequest() throws IOException { assertNotNull(workflowRequest.getWorkflowId()); assertEquals(template, workflowRequest.getTemplate()); assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertTrue(workflowRequest.getParams().isEmpty()); + + BytesStreamOutput out = new BytesStreamOutput(); + workflowRequest.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + + WorkflowRequest streamInputRequest = new WorkflowRequest(in); + + assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); + assertEquals(workflowRequest.getTemplate().toJson(), streamInputRequest.getTemplate().toJson()); + assertNull(workflowRequest.validate()); + assertFalse(workflowRequest.isProvision()); + assertTrue(workflowRequest.getParams().isEmpty()); + } + + public void testWorkflowRequestWithParams() throws IOException { + WorkflowRequest workflowRequest = new WorkflowRequest("123", template, Map.of("foo", "bar")); + assertNotNull(workflowRequest.getWorkflowId()); + assertEquals(template, workflowRequest.getTemplate()); + assertNull(workflowRequest.validate()); + assertTrue(workflowRequest.isProvision()); + assertEquals("bar", workflowRequest.getParams().get("foo")); BytesStreamOutput out = new BytesStreamOutput(); workflowRequest.writeTo(out); @@ -102,7 +135,17 @@ public void testWorkflowRequest() throws IOException { assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId()); assertEquals(workflowRequest.getTemplate().toJson(), streamInputRequest.getTemplate().toJson()); + assertNull(workflowRequest.validate()); + assertTrue(workflowRequest.isProvision()); + assertEquals("bar", workflowRequest.getParams().get("foo")); + } + public void testWorkflowRequestWithParamsNoProvision() throws IOException { + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar")) + ); + assertEquals("Params may only be included when provisioning.", ex.getMessage()); } public void testWorkflowResponse() throws IOException { From a8c3f3c5580770a803ff4f5e4532de5e02384189 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 15 Feb 2024 21:47:50 -0800 Subject: [PATCH 2/9] Pass params to ProcessNode Signed-off-by: Daniel Widdis --- .../transport/CreateWorkflowTransportAction.java | 3 ++- .../DeprovisionWorkflowTransportAction.java | 2 ++ .../transport/ProvisionWorkflowTransportAction.java | 6 +++++- .../flowframework/workflow/ProcessNode.java | 12 ++++++++++++ .../workflow/WorkflowProcessSorter.java | 4 +++- .../flowframework/workflow/ProcessNodeTests.java | 6 ++++++ .../workflow/WorkflowProcessSorterTests.java | 10 +++++----- 7 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index e2f766917..c7123b523 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -39,6 +39,7 @@ import org.opensearch.transport.TransportService; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -282,7 +283,7 @@ void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionList private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { - List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null, Collections.emptyMap()); workflowProcessSorter.validate(sortedNodes, pluginsService); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index 821ec39c8..d10233785 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -140,6 +140,7 @@ private void executeDeprovisionSequence( deprovisionStepId, workflowStepFactory.createStep(deprovisionStep), Collections.emptyMap(), + Collections.emptyMap(), new WorkflowData(Map.of(getResourceByWorkflowStep(stepName), resource.resourceId()), workflowId, deprovisionStepId), Collections.emptyList(), this.threadPool, @@ -194,6 +195,7 @@ private void executeDeprovisionSequence( pn.id(), workflowStepFactory.createStep(pn.workflowStep().getName()), pn.previousNodeInputs(), + pn.params(), pn.input(), pn.predecessors(), this.threadPool, diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index efa4b8e6b..a3aef42c6 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -126,7 +126,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes( + provisionWorkflow, + workflowId, + request.getParams() + ); workflowProcessSorter.validate(provisionProcessSequence, pluginsService); flowFrameworkIndicesHandler.isWorkflowNotStarted(workflowId, workflowIsNotStarted -> { diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 2349b9fb7..470775b90 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -30,6 +30,7 @@ public class ProcessNode { private final String id; private final WorkflowStep workflowStep; private final Map previousNodeInputs; + private final Map params; private final WorkflowData input; private final List predecessors; private final ThreadPool threadPool; @@ -44,6 +45,7 @@ public class ProcessNode { * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. * @param previousNodeInputs A map of expected inputs coming from predecessor nodes used in graph validation + * @param params Params passed on the REST path * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow * @param threadPool The OpenSearch thread pool @@ -54,6 +56,7 @@ public ProcessNode( String id, WorkflowStep workflowStep, Map previousNodeInputs, + Map params, WorkflowData input, List predecessors, ThreadPool threadPool, @@ -63,6 +66,7 @@ public ProcessNode( this.id = id; this.workflowStep = workflowStep; this.previousNodeInputs = previousNodeInputs; + this.params = params; this.input = input; this.predecessors = predecessors; this.threadPool = threadPool; @@ -94,6 +98,14 @@ public Map previousNodeInputs() { return previousNodeInputs; } + /** + * Returns the REST path params + * @return the REST path params + */ + public Map params() { + return params; + } + /** * Returns the input data for this node. * @return the input data diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index ac6d75d58..26e60cd91 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -87,9 +87,10 @@ public WorkflowProcessSorter( * Sort a workflow into a topologically sorted list of process nodes. * @param workflow A workflow with (unsorted) nodes and edges which define predecessors and successors * @param workflowId The workflowId associated with the step + * @param params Parameters passed on the REST path * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ - public List sortProcessNodes(Workflow workflow, String workflowId) { + public List sortProcessNodes(Workflow workflow, String workflowId, Map params) { if (workflow.nodes().size() > this.maxWorkflowSteps) { throw new FlowFrameworkException( "Workflow " @@ -122,6 +123,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId) ProcessNode processNode = new ProcessNode( node.id(), step, + params, node.previousNodeInputs(), data, predecessorNodes, diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 49aa4aaed..c2f2ba760 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -86,6 +86,7 @@ public String getName() { return "test"; } }, + Collections.emptyMap(), Collections.emptyMap(), new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id", "test-node-id"), List.of(successfulNode), @@ -95,6 +96,8 @@ public String getName() { ); assertEquals("A", nodeA.id()); assertEquals("test", nodeA.workflowStep().getName()); + assertEquals(Collections.emptyMap(), nodeA.previousNodeInputs()); + assertEquals(Collections.emptyMap(), nodeA.params()); assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); assertEquals("test-id", nodeA.input().getWorkflowId()); @@ -132,6 +135,7 @@ public String getName() { return "test"; } }, + Collections.emptyMap(), Collections.emptyMap(), WorkflowData.EMPTY, Collections.emptyList(), @@ -174,6 +178,7 @@ public String getName() { return "sleepy"; } }, + Collections.emptyMap(), Collections.emptyMap(), WorkflowData.EMPTY, Collections.emptyList(), @@ -213,6 +218,7 @@ public String getName() { return "test"; } }, + Collections.emptyMap(), Collections.emptyMap(), WorkflowData.EMPTY, List.of(successfulNode, failedNode), diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 6e06f252c..02488a739 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -73,7 +73,7 @@ private static Workflow parseToWorkflow(String json) throws IOException { // Wrap parser into node list private static List parseToNodes(String json) throws IOException { - return workflowProcessSorter.sortProcessNodes(parseToWorkflow(json), "123"); + return workflowProcessSorter.sortProcessNodes(parseToWorkflow(json), "123", Collections.emptyMap()); } // Wrap parser into string list @@ -376,7 +376,7 @@ public void testSuccessfulGraphValidation() throws Exception { List.of(edge1, edge2) ); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap()); workflowProcessSorter.validateGraph(sortedProcessNodes); } @@ -398,7 +398,7 @@ public void testFailedGraphValidation() throws IOException { WorkflowEdge edge = new WorkflowEdge(registerModel.id(), deployModel.id()); Workflow workflow = new Workflow(Collections.emptyMap(), List.of(registerModel, deployModel), List.of(edge)); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap()); FlowFrameworkException ex = expectThrows( FlowFrameworkException.class, () -> workflowProcessSorter.validateGraph(sortedProcessNodes) @@ -444,7 +444,7 @@ public void testSuccessfulInstalledPluginValidation() throws Exception { List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2) ); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap()); workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, List.of("opensearch-flow-framework", "opensearch-ml")); } @@ -486,7 +486,7 @@ public void testFailedInstalledPluginValidation() throws Exception { List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2) ); - List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap()); FlowFrameworkException exception = expectThrows( FlowFrameworkException.class, From 5e4a26d81fc560fce665fe23a1a519d5bfc590ed Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 15 Feb 2024 22:06:46 -0800 Subject: [PATCH 3/9] Pass params to WorkflowSteps Signed-off-by: Daniel Widdis --- .../opensearch/flowframework/util/ParseUtils.java | 4 +++- .../workflow/AbstractRegisterLocalModelStep.java | 6 ++++-- .../workflow/CreateConnectorStep.java | 6 ++++-- .../flowframework/workflow/CreateIndexStep.java | 3 ++- .../workflow/CreateIngestPipelineStep.java | 3 ++- .../flowframework/workflow/DeleteAgentStep.java | 6 ++++-- .../workflow/DeleteConnectorStep.java | 6 ++++-- .../flowframework/workflow/DeleteModelStep.java | 6 ++++-- .../flowframework/workflow/DeployModelStep.java | 6 ++++-- .../flowframework/workflow/NoOpStep.java | 3 ++- .../flowframework/workflow/ProcessNode.java | 3 ++- .../flowframework/workflow/RegisterAgentStep.java | 14 +++++++++----- .../workflow/RegisterModelGroupStep.java | 6 ++++-- .../workflow/RegisterRemoteModelStep.java | 6 ++++-- .../flowframework/workflow/ToolStep.java | 6 ++++-- .../flowframework/workflow/UndeployModelStep.java | 6 ++++-- .../flowframework/workflow/WorkflowStep.java | 4 +++- .../DeprovisionWorkflowTransportActionTests.java | 4 ++-- .../flowframework/util/ParseUtilsTests.java | 15 ++++++++++++--- .../workflow/CreateConnectorStepTests.java | 2 ++ .../workflow/CreateIndexStepTests.java | 2 ++ .../workflow/CreateIngestPipelineStepTests.java | 3 +++ .../workflow/DeleteAgentStepTests.java | 7 +++++-- .../workflow/DeleteConnectorStepTests.java | 7 +++++-- .../workflow/DeleteModelStepTests.java | 7 +++++-- .../workflow/DeployModelStepTests.java | 3 +++ .../workflow/ModelGroupStepTests.java | 3 +++ .../flowframework/workflow/NoOpStepTests.java | 1 + .../flowframework/workflow/ProcessNodeTests.java | 12 ++++++++---- .../workflow/RegisterAgentTests.java | 2 ++ .../RegisterLocalCustomModelStepTests.java | 4 ++++ .../RegisterLocalPretrainedModelStepTests.java | 4 ++++ ...RegisterLocalSparseEncodingModelStepTests.java | 4 ++++ .../workflow/RegisterRemoteModelStepTests.java | 4 ++++ .../flowframework/workflow/ToolStepTests.java | 1 + .../workflow/UndeployModelStepTests.java | 7 +++++-- 36 files changed, 138 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index fc7536177..b673a55a1 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -248,6 +248,7 @@ public static Map getStringToStringMap(Object map, String fieldN * @param currentNodeInputs Input params and content for this node, from workflow parsing * @param outputs WorkflowData content of previous steps * @param previousNodeInputs Input params for this node that come from previous steps + * @param params Params that came from REST path * @return A map containing the requiredInputKeys with their corresponding values, * and optionalInputKeys with their corresponding values if present. * Throws a {@link FlowFrameworkException} if a required key is not present. @@ -257,7 +258,8 @@ public static Map getInputsFromPreviousSteps( Set optionalInputKeys, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { // Mutable set to ensure all required keys are used Set requiredKeys = new HashSet<>(requiredInputKeys); diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 5074f3efa..64e3520a6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -79,7 +79,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture registerLocalModelFuture = PlainActionFuture.newFuture(); @@ -90,7 +91,8 @@ public PlainActionFuture execute( getOptionalKeys(), currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); // Extract common fields of OS provided text-embedding, sparse encoding and custom models diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 228b4161f..12daee204 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -74,7 +74,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture createConnectorFuture = PlainActionFuture.newFuture(); @@ -138,7 +139,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String name = (String) inputs.get(NAME_FIELD); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index eee1f94ec..2afb30077 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -63,7 +63,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture createIndexFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index d0bbed40b..138716dfb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -70,7 +70,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture createIngestPipelineFuture = PlainActionFuture.newFuture(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java index 04c1cca92..3965e4bd3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -49,7 +49,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture deleteAgentFuture = PlainActionFuture.newFuture(); @@ -82,7 +83,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String agentId = (String) inputs.get(AGENT_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index 6c3376369..81a5bf425 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -49,7 +49,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture deleteConnectorFuture = PlainActionFuture.newFuture(); @@ -82,7 +83,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String connectorId = (String) inputs.get(CONNECTOR_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java index be8e66138..f14753811 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java @@ -49,7 +49,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture deleteModelFuture = PlainActionFuture.newFuture(); @@ -82,7 +83,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String modelId = inputs.get(MODEL_ID).toString(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index dff2ff92e..aefc3d705 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -64,7 +64,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture deployModelFuture = PlainActionFuture.newFuture(); @@ -109,7 +110,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String modelId = (String) inputs.get(MODEL_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index e13181cf7..e93aba1cc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -28,7 +28,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(WorkflowData.EMPTY); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 470775b90..454bdb2d1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -171,7 +171,8 @@ public PlainActionFuture execute() { this.id, this.input, inputMap, - this.previousNodeInputs + this.previousNodeInputs, + this.params ); // If completed exceptionally, this is a no-op future.onResponse(stepFuture.actionGet(this.nodeTimeout)); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 8c36575a4..6ca558edd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -82,7 +82,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { String workflowId = currentNodeInputs.getWorkflowId(); @@ -150,7 +151,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String type = (String) inputs.get(TYPE); @@ -163,8 +165,10 @@ public void onFailure(Exception e) { : getStringToStringMap(llmParams, LLM_PARAMETERS); String[] toolsOrder = (String[]) inputs.get(TOOLS_ORDER_FIELD); List toolsList = getTools(toolsOrder, previousNodeInputs, outputs); - Object params = inputs.get(PARAMETERS_FIELD); - Map parameters = params == null ? Collections.emptyMap() : getStringToStringMap(params, PARAMETERS_FIELD); + Object parameters = inputs.get(PARAMETERS_FIELD); + Map parametersMap = parameters == null + ? Collections.emptyMap() + : getStringToStringMap(parameters, PARAMETERS_FIELD); MLMemorySpec memory = getMLMemorySpec(inputs.get(MEMORY_FIELD)); Instant createdTime = Instant.now(); Instant lastUpdateTime = createdTime; @@ -191,7 +195,7 @@ public void onFailure(Exception e) { builder.type(type) .tools(toolsList) - .parameters(parameters) + .parameters(parametersMap) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .appType(appType); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index 9acda1b6c..e04c9a916 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -65,7 +65,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture registerModelGroupFuture = PlainActionFuture.newFuture(); @@ -122,7 +123,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String modelGroupName = (String) inputs.get(NAME_FIELD); String description = (String) inputs.get(DESCRIPTION_FIELD); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 8cd184a18..b473b5922 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -63,7 +63,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture registerRemoteModelFuture = PlainActionFuture.newFuture(); @@ -77,7 +78,8 @@ public PlainActionFuture execute( optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String modelName = (String) inputs.get(NAME_FIELD); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 2a5536638..6809e7832 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -42,7 +42,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { Set requiredKeys = Set.of(TYPE); Set optionalKeys = Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE); @@ -53,7 +54,8 @@ public PlainActionFuture execute( optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String type = (String) inputs.get(TYPE); diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java index a90ff1aa8..614b36d57 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -54,7 +54,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture undeployModelFuture = PlainActionFuture.newFuture(); @@ -95,7 +96,8 @@ public void onFailure(Exception e) { optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); String modelId = inputs.get(MODEL_ID).toString(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 16cc2b200..ebc8be094 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -23,13 +23,15 @@ public interface WorkflowStep { * @param currentNodeInputs Input params and content for this node, from workflow parsing * @param outputs WorkflowData content of previous steps. * @param previousNodeInputs Input params for this node that come from previous steps + * @param params Params passed on the REST path * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. */ PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ); /** diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index b14dd2bb1..80c1ed0b8 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -124,7 +124,7 @@ public void testDeprovisionWorkflow() throws Exception { PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(WorkflowData.EMPTY); - when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap(), anyMap())).thenReturn(future); deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -154,7 +154,7 @@ public void testFailToDeprovision() throws Exception { PlainActionFuture future = PlainActionFuture.newFuture(); future.onFailure(new RuntimeException("rte")); - when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap(), anyMap())).thenReturn(future); deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 3431568e6..4f9bbddc6 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -111,13 +111,14 @@ public void testGetInputsFromPreviousSteps() { Map previousNodeInputs = Map.of("step2", "output2"); Set requiredKeys = Set.of("param1", "content1"); Set optionalKeys = Set.of("output1", "output2", "content3", "nestedMap", "nestedList", "no-output"); - + Map params = Map.of("param1", "value1"); Map inputs = ParseUtils.getInputsFromPreviousSteps( requiredKeys, optionalKeys, currentNodeInputs, outputs, - previousNodeInputs + previousNodeInputs, + params ); assertEquals("value1", inputs.get("param1")); @@ -125,6 +126,7 @@ public void testGetInputsFromPreviousSteps() { assertEquals("outputvalue1", inputs.get("output1")); assertEquals("step2outputvalue2", inputs.get("output2")); + // FIXME add a substitution test for params here // Substitutions assertEquals("outputvalue1", inputs.get("content3")); @SuppressWarnings("unchecked") @@ -138,7 +140,14 @@ public void testGetInputsFromPreviousSteps() { Set missingRequiredKeys = Set.of("not-here"); FlowFrameworkException e = assertThrows( FlowFrameworkException.class, - () -> ParseUtils.getInputsFromPreviousSteps(missingRequiredKeys, optionalKeys, currentNodeInputs, outputs, previousNodeInputs) + () -> ParseUtils.getInputsFromPreviousSteps( + missingRequiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs, + params + ) ); assertEquals("Missing required inputs [not-here] in workflow [workflowId] node [nodeId]", e.getMessage()); assertEquals(RestStatus.BAD_REQUEST, e.getRestStatus()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 242dfd02d..86dc4af47 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -103,6 +103,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -125,6 +126,7 @@ public void testCreateConnectorFailure() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 7eb02891c..4e035149e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -113,6 +113,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertFalse(future.isDone()); @@ -133,6 +134,7 @@ public void testCreateIndexStepFailure() throws ExecutionException, InterruptedE inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertFalse(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index a36e293b9..26ec7aae8 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -94,6 +94,7 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -117,6 +118,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -152,6 +154,7 @@ public void testMissingData() throws InterruptedException { incorrectData.getNodeId(), incorrectData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java index 8121f53be..9e75a2dac 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java @@ -64,7 +64,8 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(AGENT_ID, agentId), "workflowId", "nodeId")), - Map.of("step_1", AGENT_ID) + Map.of("step_1", AGENT_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); @@ -79,6 +80,7 @@ public void testNoAgentIdInOutput() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -101,7 +103,8 @@ public void testDeleteAgentFailure() throws IOException { inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(AGENT_ID, "test"), "workflowId", "nodeId")), - Map.of("step_1", AGENT_ID) + Map.of("step_1", AGENT_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index 45aa4b7ad..0e230c010 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -64,7 +64,8 @@ public void testDeleteConnector() throws IOException, ExecutionException, Interr inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(CONNECTOR_ID, connectorId), "workflowId", "nodeId")), - Map.of("step_1", CONNECTOR_ID) + Map.of("step_1", CONNECTOR_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); @@ -79,6 +80,7 @@ public void testNoConnectorIdInOutput() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -101,7 +103,8 @@ public void testDeleteConnectorFailure() throws IOException { inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(CONNECTOR_ID, "test"), "workflowId", "nodeId")), - Map.of("step_1", CONNECTOR_ID) + Map.of("step_1", CONNECTOR_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java index 4c69347cb..e19b70502 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java @@ -64,7 +64,8 @@ public void testDeleteModel() throws IOException, ExecutionException, Interrupte inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), - Map.of("step_1", MODEL_ID) + Map.of("step_1", MODEL_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteModel(any(String.class), any()); @@ -79,6 +80,7 @@ public void testNoModelIdInOutput() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -101,7 +103,8 @@ public void testDeleteModelFailure() throws IOException { inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), - Map.of("step_1", MODEL_ID) + Map.of("step_1", MODEL_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).deleteModel(any(String.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 342239def..2d6132e91 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -161,6 +161,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -187,6 +188,7 @@ public void testDeployModelFailure() { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -242,6 +244,7 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index ea46bc9ec..0df8432c4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -95,6 +95,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -122,6 +123,7 @@ public void testRegisterModelGroupFailure() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -141,6 +143,7 @@ public void testRegisterModelGroupWithNoName() throws IOException { inputDataWithNoName.getNodeId(), inputDataWithNoName, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 171c75272..21141003f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -23,6 +23,7 @@ public void testNoOpStep() throws IOException { "nodeId", WorkflowData.EMPTY, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index c2f2ba760..9e205cb89 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -74,7 +74,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture f = PlainActionFuture.newFuture(); f.onResponse(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); @@ -119,7 +120,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture future = PlainActionFuture.newFuture(); testThreadPool.schedule( @@ -162,7 +164,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture future = PlainActionFuture.newFuture(); testThreadPool.schedule( @@ -206,7 +209,8 @@ public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, - Map previousNodeInputs + Map previousNodeInputs, + Map params ) { PlainActionFuture f = PlainActionFuture.newFuture(); f.onResponse(WorkflowData.EMPTY); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index 5360a4f12..a2e0807b3 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -109,6 +109,7 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -141,6 +142,7 @@ public void testRegisterAgentFailure() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index e891d97ba..a223b1bbe 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -173,6 +173,7 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -197,6 +198,7 @@ public void testRegisterLocalCustomModelFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -246,6 +248,7 @@ public void testRegisterLocalCustomModelTaskFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -259,6 +262,7 @@ public void testMissingInputs() { "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 8eb9d7798..7b1f0221e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -167,6 +167,7 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -191,6 +192,7 @@ public void testRegisterLocalPretrainedModelFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -240,6 +242,7 @@ public void testRegisterLocalPretrainedModelTaskFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -253,6 +256,7 @@ public void testMissingInputs() { "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 6ca63b9de..6f77b9b21 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -170,6 +170,7 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -194,6 +195,7 @@ public void testRegisterLocalSparseEncodingModelFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -243,6 +245,7 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -256,6 +259,7 @@ public void testMissingInputs() { "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 50766efe5..ed26ca641 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -93,6 +93,7 @@ public void testRegisterRemoteModelSuccess() throws Exception { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -140,6 +141,7 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { deployWorkflowData.getNodeId(), deployWorkflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -163,6 +165,7 @@ public void testRegisterRemoteModelFailure() { workflowData.getNodeId(), workflowData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); @@ -177,6 +180,7 @@ public void testMissingInputs() { "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 0dc4d7960..27bb44c83 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -44,6 +44,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); diff --git a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java index 6fe1cade2..ca30929e4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java @@ -72,7 +72,8 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), - Map.of("step_1", MODEL_ID) + Map.of("step_1", MODEL_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); @@ -87,6 +88,7 @@ public void testNoModelIdInOutput() throws IOException { inputData.getNodeId(), inputData, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap() ); @@ -118,7 +120,8 @@ public void testUndeployModelFailure() throws IOException { inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), - Map.of("step_1", MODEL_ID) + Map.of("step_1", MODEL_ID), + Collections.emptyMap() ); verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); From 75f5b253c9bf5ca198a35f2858b0bed5d36fb786 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 15 Feb 2024 23:24:41 -0800 Subject: [PATCH 4/9] Substitute params Signed-off-by: Daniel Widdis --- .../opensearch/flowframework/util/ParseUtils.java | 14 ++++++++++---- .../flowframework/util/ParseUtilsTests.java | 9 +++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index b673a55a1..6192d8e6d 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -310,11 +310,11 @@ public static Map getInputsFromPreviousSteps( Map valueMap = (Map) value; value = valueMap.entrySet() .stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs))); + .collect(Collectors.toMap(Map.Entry::getKey, e -> conditionallySubstitute(e.getValue(), outputs, params))); } else if (value instanceof List) { - value = ((List) value).stream().map(v -> conditionallySubstitute(v, outputs)).collect(Collectors.toList()); + value = ((List) value).stream().map(v -> conditionallySubstitute(v, outputs, params)).collect(Collectors.toList()); } else { - value = conditionallySubstitute(value, outputs); + value = conditionallySubstitute(value, outputs, params); } // Add value to inputs and mark that a required key was present inputs.put(key, value); @@ -338,15 +338,21 @@ public static Map getInputsFromPreviousSteps( return inputs; } - private static Object conditionallySubstitute(Object value, Map outputs) { + private static Object conditionallySubstitute(Object value, Map outputs, Map params) { if (value instanceof String) { Matcher m = SUBSTITUTION_PATTERN.matcher((String) value); if (m.matches()) { + // Try matching a previous step+value pair WorkflowData data = outputs.get(m.group(1)); if (data != null && data.getContent().containsKey(m.group(2))) { return data.getContent().get(m.group(2)); } } + // Replace all params if present + for (Entry e : params.entrySet()) { + String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}"; + value = ((String) value).replaceAll(regex, e.getValue()); + } } return value; } diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 4f9bbddc6..5148f9251 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -89,7 +89,8 @@ public void testGetInputsFromPreviousSteps() { Map.entry("param1", 2), Map.entry("content3", "${{step1.output1}}"), Map.entry("nestedMap", Map.of("content4", "${{step3.output3}}")), - Map.entry("nestedList", List.of("${{step4.output4}}")) + Map.entry("nestedList", List.of("${{step4.output4}}")), + Map.entry("content5", "${{pathparam1}} plus ${{pathparam1}} is ${{pathparam2}} but I didn't replace ${{pathparam3}}") ), Map.of("param1", "value1"), "workflowId", @@ -110,8 +111,8 @@ public void testGetInputsFromPreviousSteps() { ); Map previousNodeInputs = Map.of("step2", "output2"); Set requiredKeys = Set.of("param1", "content1"); - Set optionalKeys = Set.of("output1", "output2", "content3", "nestedMap", "nestedList", "no-output"); - Map params = Map.of("param1", "value1"); + Set optionalKeys = Set.of("output1", "output2", "content3", "nestedMap", "nestedList", "no-output", "content5"); + Map params = Map.ofEntries(Map.entry("pathparam1", "one"), Map.entry("pathparam2", "two")); Map inputs = ParseUtils.getInputsFromPreviousSteps( requiredKeys, optionalKeys, @@ -126,7 +127,6 @@ public void testGetInputsFromPreviousSteps() { assertEquals("outputvalue1", inputs.get("output1")); assertEquals("step2outputvalue2", inputs.get("output2")); - // FIXME add a substitution test for params here // Substitutions assertEquals("outputvalue1", inputs.get("content3")); @SuppressWarnings("unchecked") @@ -135,6 +135,7 @@ public void testGetInputsFromPreviousSteps() { @SuppressWarnings("unchecked") List nestedList = (List) inputs.get("nestedList"); assertEquals(List.of("step4outputvalue4"), nestedList); + assertEquals("one plus one is two but I didn't replace ${{pathparam3}}", inputs.get("content5")); assertNull(inputs.get("no-output")); Set missingRequiredKeys = Set.of("not-here"); From 60e1fcfb7eab11fa5c348eec1d318c5974b46fce Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 15 Feb 2024 23:37:39 -0800 Subject: [PATCH 5/9] Add change log Signed-off-by: Daniel Widdis --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3486c6e90..e29d0bc24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x) ### Features ### Enhancements +- Substitute REST path parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) + ### Bug Fixes ### Infrastructure ### Documentation From a26211f7bb8bfc0fb6e7d53ae0f0452931695a05 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 16 Feb 2024 11:50:05 -0800 Subject: [PATCH 6/9] Improve param consuming checks, add coverage Signed-off-by: Daniel Widdis --- .../rest/RestCreateWorkflowAction.java | 16 +++++-- .../rest/RestProvisionWorkflowAction.java | 10 ++-- .../rest/RestCreateWorkflowActionTests.java | 46 ++++++++++++++++++- 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 595ed4932..bf4943403 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -27,10 +27,10 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; -import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -78,16 +78,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - Map params = Collections.emptyMap(); final List validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); // If provisioning, consume all other params and pass to provision transport action - if (provision) { - params = request.params() + Map params = provision + ? request.params() + .keySet() + .stream() + .filter(k -> !validCreateParams.contains(k)) + .collect(Collectors.toMap(Function.identity(), request::param)) + : request.params() .entrySet() .stream() .filter(e -> !validCreateParams.contains(e.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", @@ -98,6 +101,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } if (!provision && !params.isEmpty()) { + // Consume params and content so custom exception is processed + params.keySet().stream().forEach(request::param); + request.content(); FlowFrameworkException ffe = new FlowFrameworkException( "Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.", RestStatus.BAD_REQUEST diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 7c0fdb3d4..f7d61a114 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -28,10 +28,9 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; -import static org.opensearch.flowframework.common.CommonValue.VALIDATION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -72,12 +71,11 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - final List excludeParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); Map params = request.params() - .entrySet() + .keySet() .stream() - .filter(e -> !WORKFLOW_ID.equals(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + .filter(k -> !WORKFLOW_ID.equals(k)) + .collect(Collectors.toMap(Function.identity(), request::param)); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index fcdaf5757..1d99cf517 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -10,6 +10,7 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -19,6 +20,8 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -30,12 +33,16 @@ import java.util.Locale; import java.util.Map; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class RestCreateWorkflowActionTests extends OpenSearchTestCase { + private String validTemplate; private String invalidTemplate; private RestCreateWorkflowAction createWorkflowRestAction; private String createWorkflowPath; @@ -70,7 +77,8 @@ public void setUp() throws Exception { ); // Invalid template configuration, wrong field name - this.invalidTemplate = template.toJson().replace("use_case", "invalid"); + this.validTemplate = template.toJson(); + this.invalidTemplate = this.validTemplate.replace("use_case", "invalid"); this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); @@ -92,6 +100,42 @@ public void testRestCreateWorkflowActionRoutes() { } + public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("foo", "bar"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + + public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of("foo", "bar")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains( + "Only the parameters [workflow_id, validation, provision] are permitted unless the provision parameter is set to true." + ) + ); + } + public void testInvalidCreateWorkflowRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) From d6266856f57a28dd1af792a09e4521596d68432d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 19 Feb 2024 16:44:15 -0800 Subject: [PATCH 7/9] Allow specifying key-value pairs in body Signed-off-by: Daniel Widdis --- .../rest/RestProvisionWorkflowAction.java | 26 +++++-- .../RestProvisionWorkflowActionTests.java | 67 ++++++++++++++++--- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index f7d61a114..84c518615 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -16,6 +16,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; @@ -31,6 +32,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -71,23 +73,37 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + // Get any other params from path Map params = request.params() .keySet() .stream() .filter(k -> !WORKFLOW_ID.equals(k)) .collect(Collectors.toMap(Function.identity(), request::param)); try { + // If body is included get any params from body + if (request.hasContent()) { + try (XContentParser parser = request.contentParser()) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String key = parser.currentName(); + if (params.containsKey(key)) { + throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST); + } + if (parser.nextToken() != XContentParser.Token.VALUE_STRING) { + throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST); + } + params.put(key, parser.text()); + } + } catch (IOException e) { + throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST); + } + } if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", RestStatus.FORBIDDEN ); } - // Validate content - if (request.hasContent()) { - // BaseRestHandler will give appropriate error message - return channel -> channel.sendResponse(null); - } // Validate params if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index 6ddd83d11..fd5cd478d 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -9,10 +9,13 @@ package org.opensearch.flowframework.rest; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -21,8 +24,11 @@ import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -49,7 +55,7 @@ public void testRestProvisionWorkflowActionName() { assertEquals("provision_workflow_action", name); } - public void testRestProvisiionWorkflowActionRoutes() { + public void testRestProvisionWorkflowActionRoutes() { List routes = provisionWorkflowRestAction.routes(); assertEquals(1, routes.size()); assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); @@ -71,20 +77,61 @@ public void testNullWorkflowId() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() { + public void testContentParsing() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) - .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .withParams(Map.of("workflow_id", "abc")) + .withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { - provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); - }); - assertEquals( - "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_provision] does not support having a body", - ex.getMessage() - ); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + + public void testContentParsingDuplicate() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.ofEntries(Map.entry("workflow_id", "abc"), Map.entry("foo", "bar"))) + .withContent(new BytesArray("{\"bar\": \"none\", \"foo\": \"baz\"}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + // assertEquals("", channel.capturedResponse().content().utf8ToString()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Duplicate key foo")); + } + + public void testContentParsingBadType() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.of("workflow_id", "abc")) + .withContent(new BytesArray("{\"foo\": 123}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body fields must have string values")); + } + + public void testContentParsingError() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withContent(new BytesArray("not json"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body parsing failed")); } public void testFeatureFlagNotEnabled() throws Exception { From 3d185311049db25f1c7c9390f2737bbbb4d7feda Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 19 Feb 2024 17:20:40 -0800 Subject: [PATCH 8/9] Update title in change log Signed-off-by: Daniel Widdis --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e29d0bc24..ad2517357 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x) ### Features ### Enhancements -- Substitute REST path parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) +- Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525)) ### Bug Fixes ### Infrastructure From 098e75fb4a18ade3a3deceb38cb6f14d107c5ece Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 21 Feb 2024 15:21:22 -0800 Subject: [PATCH 9/9] Refactor param and content map generation to a new method Signed-off-by: Daniel Widdis --- .../rest/RestProvisionWorkflowAction.java | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 84c518615..cbd9afd82 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -73,31 +73,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - // Get any other params from path - Map params = request.params() - .keySet() - .stream() - .filter(k -> !WORKFLOW_ID.equals(k)) - .collect(Collectors.toMap(Function.identity(), request::param)); try { - // If body is included get any params from body - if (request.hasContent()) { - try (XContentParser parser = request.contentParser()) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String key = parser.currentName(); - if (params.containsKey(key)) { - throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST); - } - if (parser.nextToken() != XContentParser.Token.VALUE_STRING) { - throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST); - } - params.put(key, parser.text()); - } - } catch (IOException e) { - throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST); - } - } + Map params = parseParamsAndContent(request); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", @@ -132,4 +109,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } + private Map parseParamsAndContent(RestRequest request) { + // Get any other params from path + Map params = request.params() + .keySet() + .stream() + .filter(k -> !WORKFLOW_ID.equals(k)) + .collect(Collectors.toMap(Function.identity(), request::param)); + // If body is included get any params from body + if (request.hasContent()) { + try (XContentParser parser = request.contentParser()) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String key = parser.currentName(); + if (params.containsKey(key)) { + throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST); + } + if (parser.nextToken() != XContentParser.Token.VALUE_STRING) { + throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST); + } + params.put(key, parser.text()); + } + } catch (IOException e) { + throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST); + } + } + return params; + } }