diff --git a/CHANGELOG.md b/CHANGELOG.md index 594c1772e..dd0b08c73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Reduced number of assets to upload when submitting a dispatch. +- Support range requests for alternate asset representations. ### Operations diff --git a/covalent_dispatcher/_service/assets.py b/covalent_dispatcher/_service/assets.py index 0664e5058..4f7d824b8 100644 --- a/covalent_dispatcher/_service/assets.py +++ b/covalent_dispatcher/_service/assets.py @@ -68,7 +68,7 @@ def get_node_asset( node_id: int, key: ElectronAssetKey, representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), + Range: Union[str, None] = Header(default=None, pattern=range_regex), ): """Returns an asset for an electron. @@ -105,14 +105,20 @@ def get_node_asset( with workflow_db.session() as session: asset = node.get_asset(key=key.value, session=session) - # Explicit representation overrides the byte range - if representation is None or ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + # Handle requests for alt representations. Only assets of type + # TRANSPORTABLE admit dual representations. Translate the byte + # range request to the byte range containing the + # representation. 
+ if representation is not None: + if ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + raise HTTPException(400, "Asset does not admit alt representations") + elif representation == AssetRepresentation.string: + start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri) + else: + start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri) + + start_byte = min(end_offset, start_offset + start_byte) + end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte) app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) @@ -128,7 +134,7 @@ def get_dispatch_asset( dispatch_id: str, key: DispatchAssetKey, representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), + Range: Union[str, None] = Header(default=None, pattern=range_regex), ): """Returns a dynamic asset for a workflow @@ -162,14 +168,20 @@ def get_dispatch_asset( with workflow_db.session() as session: asset = result_object.get_asset(key=key.value, session=session) - # Explicit representation overrides the byte range - if representation is None or RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + # Handle requests for alt representations. Only assets of type + # TRANSPORTABLE admit dual representations. Translate the byte + # range request to the byte range containing the + # representation. 
+ if representation is not None: + if RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + raise HTTPException(400, "Asset does not admit alt representations") + elif representation == AssetRepresentation.string: + start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri) + else: + start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri) + + start_byte = min(end_offset, start_offset + start_byte) + end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte) app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) @@ -184,7 +196,7 @@ def get_lattice_asset( dispatch_id: str, key: LatticeAssetKey, representation: Union[AssetRepresentation, None] = None, - Range: Union[str, None] = Header(default=None, regex=range_regex), + Range: Union[str, None] = Header(default=None, pattern=range_regex), ): """Returns a static asset for a workflow @@ -193,8 +205,6 @@ def get_lattice_asset( key: The name of the asset representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` range: (optional) range request header - - If `representation` is specified, it will override the range request. """ start_byte = 0 end_byte = -1 @@ -218,14 +228,20 @@ def get_lattice_asset( with workflow_db.session() as session: asset = result_object.lattice.get_asset(key=key.value, session=session) - # Explicit representation overrides the byte range - if representation is None or LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: - start_byte = start_byte - end_byte = end_byte - elif representation == AssetRepresentation.string: - start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) - else: - start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + # Handle requests for alt representations. Only assets of type + # TRANSPORTABLE admit dual representations. 
Translate the byte + # range request to the byte range containing the + # representation. + if representation is not None: + if LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + raise HTTPException(400, "Asset does not admit alt representations") + elif representation == AssetRepresentation.string: + start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri) + else: + start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri) + + start_byte = min(end_offset, start_offset + start_byte) + end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte) app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py index 5f704ca43..aab8e3b37 100644 --- a/tests/covalent_dispatcher_tests/_service/assets_test.py +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -179,9 +179,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +@pytest.mark.parametrize( + "rep,start_offset,start_byte_req,end_offset,end_byte_req,", + [("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)], +) def test_get_node_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte + mocker, + client, + test_db, + mock_result_object, + rep, + start_offset, + start_byte_req, + end_offset, + end_byte_req, ): """ Test get node asset @@ -221,16 +232,39 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in ) params = {"representation": rep} + headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"} 
mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get( - f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", params=params + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", + headers=headers, + params=params, ) + start_byte = min(start_offset + start_byte_req, end_offset) + end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset) assert resp.text == test_str[start_byte:end_byte] assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] +def test_get_node_asset_rep_nontransportable(mocker, mock_result_object, client, test_db): + """Test alt representations of assets other than TransportableObject""" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + return_value=mock_result_object, + ) + node_id = 0 + dispatch_id = "test_get_node_asset_rep_nontransportable" + key = "stdout" + params = {"representation": "string"} + resp = client.get( + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", params=params + ) + assert resp.status_code == 400 + + def test_get_node_asset_bad_dispatch_id(mocker, client): """ Test get node asset @@ -319,9 +353,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +@pytest.mark.parametrize( + "rep,start_offset,start_byte_req,end_offset,end_byte_req,", + [("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)], +) def test_get_lattice_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte + mocker, + client, + test_db, + mock_result_object, + rep, + start_offset, + start_byte_req, + end_offset, + end_byte_req, ): """ Test get lattice asset @@ -361,13 +406,34 
@@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") params = {"representation": rep} + headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"} - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params) + resp = client.get( + f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", headers=headers, params=params + ) + + start_byte = min(start_offset + start_byte_req, end_offset) + end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset) assert resp.text == test_str[start_byte:end_byte] assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] +def test_get_lattice_asset_rep_nontransportable(mocker, mock_result_object, client, test_db): + """Test alt representations of assets other than TransportableObject""" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + return_value=mock_result_object, + ) + dispatch_id = "test_get_node_asset_rep_nontransportable" + key = "workflow_function_string" + params = {"representation": "string"} + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params) + assert resp.status_code == 400 + + def test_get_lattice_asset_bad_dispatch_id(mocker, client): """ Test get lattice asset @@ -458,9 +524,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] -@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +@pytest.mark.parametrize( + "rep,start_offset,start_byte_req,end_offset,end_byte_req,", + [("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)], +) def test_get_dispatch_asset_rep( - mocker, client, test_db, mock_result_object, rep, start_byte, end_byte + 
mocker, + client, + test_db, + mock_result_object, + rep, + start_offset, + start_byte_req, + end_offset, + end_byte_req, ): """ Test get dispatch asset @@ -500,13 +577,33 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") params = {"representation": rep} + headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"} - resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params) + resp = client.get( + f"/api/v2/dispatches/{dispatch_id}/assets/{key}", headers=headers, params=params + ) + start_byte = min(start_offset + start_byte_req, end_offset) + end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset) assert resp.text == test_str[start_byte:end_byte] assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] +def test_get_dispatch_asset_rep_nontransportable(mocker, mock_result_object, client, test_db): + """Test alt representations of assets other than TransportableObject""" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + return_value=mock_result_object, + ) + dispatch_id = "test_get_node_asset_rep_nontransportable" + key = "error" + params = {"representation": "string"} + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params) + assert resp.status_code == 400 + + def test_get_dispatch_asset_bad_dispatch_id(mocker, client): """ Test get dispatch asset