diff --git a/covalent/_api/apiclient.py b/covalent/_api/apiclient.py index c4c2a5492..d3be6bd4a 100644 --- a/covalent/_api/apiclient.py +++ b/covalent/_api/apiclient.py @@ -33,7 +33,7 @@ def __init__(self, dispatcher_addr: str, adapter: HTTPAdapter = None, auto_raise self.adapter = adapter self.auto_raise = auto_raise - def prepare_headers(self, **kwargs): + def prepare_headers(self, kwargs): extra_headers = CovalentAPIClient.get_extra_headers() headers = kwargs.get("headers", {}) if headers: @@ -42,7 +42,7 @@ def prepare_headers(self, **kwargs): return headers def get(self, endpoint: str, **kwargs): - headers = self.prepare_headers(**kwargs) + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -62,7 +62,7 @@ def get(self, endpoint: str, **kwargs): return r def put(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -81,7 +81,7 @@ def put(self, endpoint: str, **kwargs): return r def post(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: @@ -100,7 +100,7 @@ def post(self, endpoint: str, **kwargs): return r def delete(self, endpoint: str, **kwargs): - headers = self.prepare_headers() + headers = self.prepare_headers(kwargs) url = self.dispatcher_addr + endpoint try: with requests.Session() as session: diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 8760cec96..f5ad26bdd 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import tempfile from copy import deepcopy from functools import wraps @@ -616,14 +617,16 @@ def _upload_asset(local_uri, remote_uri): local_path = local_uri with open(local_path, "rb") as reader: - app_log.debug(f"uploading to {remote_uri}") + content_length = os.path.getsize(local_path) f = furl(remote_uri) scheme = f.scheme host = f.host port = f.port dispatcher_addr = f"{scheme}://{host}:{port}" - endpoint = str(f.path) + endpoint = f"{str(f.path)}?{str(f.query)}" api_client = APIClient(dispatcher_addr) - - r = api_client.put(endpoint, data=reader) + if content_length == 0: + r = api_client.put(endpoint, headers={"Content-Length": "0"}, data=reader.read()) + else: + r = api_client.put(endpoint, data=reader) r.raise_for_status() diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 12f18cbf5..b38ea3d5e 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -915,6 +915,7 @@ def _build_sublattice_graph(sub: Lattice, json_parent_metadata: str, *args, **kw return recv_manifest.model_dump_json() except Exception as ex: + if os.environ.get("COVALENT_DISABLE_LEGACY_SUBLATTICES") == "1": + raise # Fall back to legacy sublattice handling - print("Falling back to legacy sublattice handling") return sub.serialize_to_json() diff --git a/tests/covalent_tests/api/apiclient_test.py b/tests/covalent_tests/api/apiclient_test.py new file mode 100644 index 000000000..20d22cb5a --- /dev/null +++ b/tests/covalent_tests/api/apiclient_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Unit tests for the API client""" + +import json +from unittest.mock import MagicMock + +import pytest + +from covalent._api.apiclient import CovalentAPIClient + + +@pytest.fixture +def mock_session(): + sess = MagicMock() + + +def test_header_injection(mocker): + extra_headers = {"x-custom-header": "value"} + headers = {"Content-Length": "128"} + expected_headers = headers.copy() + expected_headers.update(extra_headers) + mock_session = MagicMock() + environ = {"COVALENT_EXTRA_HEADERS": json.dumps(extra_headers)} + mocker.patch("os.environ", environ) + mocker.patch("requests.Session.__enter__", return_value=mock_session) + + CovalentAPIClient("http://localhost").post("/docs", headers=headers) + mock_session.post.assert_called_with("http://localhost/docs", headers=expected_headers) diff --git a/tests/covalent_tests/workflow/electron_test.py b/tests/covalent_tests/workflow/electron_test.py index a3db76bb7..8925d9bae 100644 --- a/tests/covalent_tests/workflow/electron_test.py +++ b/tests/covalent_tests/workflow/electron_test.py @@ -201,6 +201,52 @@ def workflow(x): assert parent_metadata[k] == lattice.metadata[k] +def test_build_sublattice_graph_prevent_fallback(mocker): + """ + Test preventing falling back to monolithic sublattice dispatch. + """ + dispatch_id = "test_build_sublattice_graph_prevent_fallback" + + @ct.electron + def task(x): + return x + + @ct.lattice + def workflow(x): + return task(x) + + parent_metadata = { + "executor": "parent_executor", + "executor_data": {}, + "workflow_executor": "my_postprocessor", + "workflow_executor_data": {}, + "hooks": { + "deps": {"bash": None, "pip": None}, + "call_before": [], + "call_after": [], + }, + "triggers": "mock-trigger", + "qelectron_data_exists": False, + "results_dir": None, + } + + # Omit the required environment variables + mock_environ = {"COVALENT_DISABLE_LEGACY_SUBLATTICES": "1"} + + mock_reg = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest", + ) + + mock_upload_assets = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets", + ) + + mocker.patch("os.environ", mock_environ) + + with pytest.raises(Exception): + json_lattice = _build_sublattice_graph(workflow, json.dumps(parent_metadata), 1) + + def test_wait_for_building(): """Test to check whether the graph is built correctly with `wait_for`."""