Skip to content

Commit

Permalink
Assets: support range requests for alt representations
Browse files Browse the repository at this point in the history
  • Loading branch information
cjao committed Dec 16, 2023
1 parent 6e331d4 commit 37e9912
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Reduced number of assets to upload when submitting a dispatch.
- Support range requests for alternate asset representations.

### Operations

Expand Down
74 changes: 45 additions & 29 deletions covalent_dispatcher/_service/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_node_asset(
node_id: int,
key: ElectronAssetKey,
representation: Union[AssetRepresentation, None] = None,
Range: Union[str, None] = Header(default=None, regex=range_regex),
Range: Union[str, None] = Header(default=None, pattern=range_regex),
):
"""Returns an asset for an electron.
Expand Down Expand Up @@ -105,14 +105,20 @@ def get_node_asset(
with workflow_db.session() as session:
asset = node.get_asset(key=key.value, session=session)

# Explicit representation overrides the byte range
if representation is None or ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
start_byte = start_byte
end_byte = end_byte
elif representation == AssetRepresentation.string:
start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri)
else:
start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri)
# Handle requests for alt representations. Only assets of type
# TRANSPORTABLE admit dual representations. Translate the byte
# range request to the byte range containing the
# representation.
if representation is not None:
if ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
raise HTTPException(400, "Asset does not admit alt representations")
elif representation == AssetRepresentation.string:
start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri)
else:
start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri)

start_byte = min(end_offset, start_offset + start_byte)
end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte)

app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}")
generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte)
Expand All @@ -128,7 +134,7 @@ def get_dispatch_asset(
dispatch_id: str,
key: DispatchAssetKey,
representation: Union[AssetRepresentation, None] = None,
Range: Union[str, None] = Header(default=None, regex=range_regex),
Range: Union[str, None] = Header(default=None, pattern=range_regex),
):
"""Returns a dynamic asset for a workflow
Expand Down Expand Up @@ -162,14 +168,20 @@ def get_dispatch_asset(
with workflow_db.session() as session:
asset = result_object.get_asset(key=key.value, session=session)

# Explicit representation overrides the byte range
if representation is None or RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
start_byte = start_byte
end_byte = end_byte
elif representation == AssetRepresentation.string:
start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri)
else:
start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri)
# Handle requests for alt representations. Only assets of type
# TRANSPORTABLE admit dual representations. Translate the byte
# range request to the byte range containing the
# representation.
if representation is not None:
if RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
raise HTTPException(400, "Asset does not admit alt representations")
elif representation == AssetRepresentation.string:
start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri)
else:
start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri)

start_byte = min(end_offset, start_offset + start_byte)
end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte)

app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}")
generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte)
Expand All @@ -184,7 +196,7 @@ def get_lattice_asset(
dispatch_id: str,
key: LatticeAssetKey,
representation: Union[AssetRepresentation, None] = None,
Range: Union[str, None] = Header(default=None, regex=range_regex),
Range: Union[str, None] = Header(default=None, pattern=range_regex),
):
"""Returns a static asset for a workflow
Expand All @@ -193,8 +205,6 @@ def get_lattice_asset(
key: The name of the asset
representation: (optional) the representation ("string" or "pickle") of a `TransportableObject`
range: (optional) range request header
If `representation` is specified, it will override the range request.
"""
start_byte = 0
end_byte = -1
Expand All @@ -218,14 +228,20 @@ def get_lattice_asset(
with workflow_db.session() as session:
asset = result_object.lattice.get_asset(key=key.value, session=session)

# Explicit representation overrides the byte range
if representation is None or LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
start_byte = start_byte
end_byte = end_byte
elif representation == AssetRepresentation.string:
start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri)
else:
start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri)
# Handle requests for alt representations. Only assets of type
# TRANSPORTABLE admit dual representations. Translate the byte
# range request to the byte range containing the
# representation.
if representation is not None:
if LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE:
raise HTTPException(400, "Asset does not admit alt representations")
elif representation == AssetRepresentation.string:
start_offset, end_offset = _get_tobj_string_offsets(asset.internal_uri)
else:
start_offset, end_offset = _get_tobj_pickle_offsets(asset.internal_uri)

start_byte = min(end_offset, start_offset + start_byte)
end_byte = end_offset if end_byte < 0 else min(end_offset, start_offset + end_byte)

app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}")
generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte)
Expand Down
115 changes: 106 additions & 9 deletions tests/covalent_dispatcher_tests/_service/assets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0]


@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)])
@pytest.mark.parametrize(
"rep,start_offset,start_byte_req,end_offset,end_byte_req,",
[("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)],
)
def test_get_node_asset_rep(
mocker, client, test_db, mock_result_object, rep, start_byte, end_byte
mocker,
client,
test_db,
mock_result_object,
rep,
start_offset,
start_byte_req,
end_offset,
end_byte_req,
):
"""
Test get node asset
Expand Down Expand Up @@ -221,16 +232,39 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
)

params = {"representation": rep}
headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"}
mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status")

resp = client.get(
f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", params=params
f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}",
headers=headers,
params=params,
)
start_byte = min(start_offset + start_byte_req, end_offset)
end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset)

assert resp.text == test_str[start_byte:end_byte]
assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0]


def test_get_node_asset_rep_nontransportable(mocker, mock_result_object, client, test_db):
    """Requesting an alt representation of a non-TransportableObject electron asset gives HTTP 400."""

    # Route the service's DB access and result-object lookup to the test fixtures.
    mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db)
    mocker.patch(
        "covalent_dispatcher._service.assets.get_cached_result_object",
        return_value=mock_result_object,
    )

    dispatch_id = "test_get_node_asset_rep_nontransportable"
    node_id = 0
    key = "stdout"
    url = f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}"

    response = client.get(url, params={"representation": "string"})

    assert response.status_code == 400


def test_get_node_asset_bad_dispatch_id(mocker, client):
"""
Test get node asset
Expand Down Expand Up @@ -319,9 +353,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0]


@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)])
@pytest.mark.parametrize(
"rep,start_offset,start_byte_req,end_offset,end_byte_req,",
[("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)],
)
def test_get_lattice_asset_rep(
mocker, client, test_db, mock_result_object, rep, start_byte, end_byte
mocker,
client,
test_db,
mock_result_object,
rep,
start_offset,
start_byte_req,
end_offset,
end_byte_req,
):
"""
Test get lattice asset
Expand Down Expand Up @@ -361,13 +406,34 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status")

params = {"representation": rep}
headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"}

resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params)
resp = client.get(
f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", headers=headers, params=params
)

start_byte = min(start_offset + start_byte_req, end_offset)
end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset)

assert resp.text == test_str[start_byte:end_byte]
assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0]


def test_get_lattice_asset_rep_nontransportable(mocker, mock_result_object, client, test_db):
    """Test alt representations of lattice assets other than TransportableObject.

    Asking for an explicit representation of the `workflow_function_string`
    lattice asset is expected to be rejected with HTTP 400.
    """

    # Route the service's DB access and result-object lookup to the test fixtures.
    mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db)
    mocker.patch(
        "covalent_dispatcher._service.assets.get_cached_result_object",
        return_value=mock_result_object,
    )
    # Use an id named after this test (it was previously copy-pasted from the
    # node-asset test); the result-object lookup is patched above.
    dispatch_id = "test_get_lattice_asset_rep_nontransportable"
    key = "workflow_function_string"
    params = {"representation": "string"}
    resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params)
    assert resp.status_code == 400


def test_get_lattice_asset_bad_dispatch_id(mocker, client):
"""
Test get lattice asset
Expand Down Expand Up @@ -458,9 +524,20 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0]


@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)])
@pytest.mark.parametrize(
"rep,start_offset,start_byte_req,end_offset,end_byte_req,",
[("string", 0, 0, 6, 3), ("string", 0, 2, 6, ""), ("object", 6, 0, 12, 3)],
)
def test_get_dispatch_asset_rep(
mocker, client, test_db, mock_result_object, rep, start_byte, end_byte
mocker,
client,
test_db,
mock_result_object,
rep,
start_offset,
start_byte_req,
end_offset,
end_byte_req,
):
"""
Test get dispatch asset
Expand Down Expand Up @@ -500,13 +577,33 @@ def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: in
mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status")

params = {"representation": rep}
headers = {"Range": f"bytes={start_byte_req}-{end_byte_req}"}

resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params)
resp = client.get(
f"/api/v2/dispatches/{dispatch_id}/assets/{key}", headers=headers, params=params
)
start_byte = min(start_offset + start_byte_req, end_offset)
end_byte = end_offset if not end_byte_req else min(start_offset + end_byte_req, end_offset)

assert resp.text == test_str[start_byte:end_byte]
assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0]


def test_get_dispatch_asset_rep_nontransportable(mocker, mock_result_object, client, test_db):
    """Test alt representations of dispatch assets other than TransportableObject.

    Asking for an explicit representation of the `error` dispatch asset is
    expected to be rejected with HTTP 400.
    """

    # Route the service's DB access and result-object lookup to the test fixtures.
    mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db)
    mocker.patch(
        "covalent_dispatcher._service.assets.get_cached_result_object",
        return_value=mock_result_object,
    )
    # Use an id named after this test (it was previously copy-pasted from the
    # node-asset test); the result-object lookup is patched above.
    dispatch_id = "test_get_dispatch_asset_rep_nontransportable"
    key = "error"
    params = {"representation": "string"}
    resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params)
    assert resp.status_code == 400


def test_get_dispatch_asset_bad_dispatch_id(mocker, client):
"""
Test get dispatch asset
Expand Down

0 comments on commit 37e9912

Please sign in to comment.