From 4dda93b651784428278b43ba33869be2be3fd3dc Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Tue, 21 Nov 2023 12:39:14 +0100 Subject: [PATCH] download hook (#1571) --- CHANGELOG.md | 3 +- .../backends/hosted_file_operations.py | 24 ++++++++++++++-- .../backends/test_hosted_file_operations.py | 28 +++++++++++++++++-- 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6f1fd4b5..7d3fc6444 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ -## [UNRELEASED] 1.8.5 +## [UNRELEASED] neptune 1.8.5 ### Changes +- Enabled hooks for internal downloading functions used by the hosted backend ([#1571](https://github.com/neptune-ai/neptune-client/pull/1571)) - Added timestamp of operation put to disk queue ([#1569](https://github.com/neptune-ai/neptune-client/pull/1569)) diff --git a/src/neptune/internal/backends/hosted_file_operations.py b/src/neptune/internal/backends/hosted_file_operations.py index 5df6045d9..1c1034734 100644 --- a/src/neptune/internal/backends/hosted_file_operations.py +++ b/src/neptune/internal/backends/hosted_file_operations.py @@ -28,6 +28,7 @@ from io import BytesIO from typing import ( AnyStr, + Callable, Dict, Iterable, List, @@ -420,6 +421,9 @@ def download_file_attribute( container_id: str, attribute: str, destination: Optional[str] = None, + pre_download_hook: Callable[[int], None] = lambda x: None, + download_iter_hook: Callable[[int], None] = lambda x: None, + post_download_hook: Callable[[], None] = lambda: None, ): url = build_operation_url( swagger_client.swagger_spec.api_url, @@ -431,13 +435,16 @@ def download_file_attribute( headers={"Accept": "application/octet-stream"}, query_params={"experimentId": container_id, "attribute": attribute}, ) - _store_response_as_file(response, destination) + _store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook) def download_file_set_attribute( swagger_client: SwaggerClientWrapper, download_id: str, destination: Optional[str] = None, + pre_download_hook: Callable[[int], None] = lambda x: None, + download_iter_hook: Callable[[int], None] = lambda x: None, + post_download_hook: Callable[[], None] = lambda: None, ): download_url: Optional[str] = _get_download_url(swagger_client, download_id) next_sleep = 0.5 @@ -451,7 +458,7 @@ def download_file_set_attribute( url=download_url, headers={"Accept": "application/zip"}, ) - _store_response_as_file(response, destination) + _store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook) def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str): @@ -460,18 +467,29 @@ def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str): return download_request.downloadUrl -def _store_response_as_file(response: Response, destination: Optional[str] = None): +def _store_response_as_file( + response: Response, + destination: Optional[str] = None, + pre_download_hook: Callable[[int], None] = lambda x: None, + download_iter_hook: Callable[[int], None] = lambda x: None, + post_download_hook: Callable[[], None] = lambda: None, +) -> None: if destination is None: target_file = _get_content_disposition_filename(response) elif os.path.isdir(destination): target_file = os.path.join(destination, _get_content_disposition_filename(response)) else: target_file = destination + + total_size = int(response.headers.get("content-length", 0)) + pre_download_hook(total_size) with response: with open(target_file, "wb") as f: for chunk in response.iter_content(chunk_size=1024 * 1024): if chunk: f.write(chunk) + download_iter_hook(len(chunk) if chunk else 0) + post_download_hook() def _get_content_disposition_filename(response: Response) -> str: diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py b/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py index cc2443238..308971ccb 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py @@ -102,12 +102,19 @@ def test_download_file_attribute(self, download_raw, store_response_mock): swagger_mock = self._get_swagger_mock() exp_uuid = str(uuid.uuid4()) + pre_download_hook = MagicMock() + download_iter_hook = MagicMock() + post_download_hook = MagicMock() + # when download_file_attribute( swagger_client=swagger_mock, container_id=exp_uuid, attribute="some/attribute", destination=None, + pre_download_hook=pre_download_hook, + download_iter_hook=download_iter_hook, + post_download_hook=post_download_hook, ) # then @@ -117,7 +124,9 @@ def test_download_file_attribute(self, download_raw, store_response_mock): headers={"Accept": "application/octet-stream"}, query_params={"experimentId": str(exp_uuid), "attribute": "some/attribute"}, ) - store_response_mock.assert_called_once_with(download_raw.return_value, None) + store_response_mock.assert_called_once_with( + download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook + ) @patch("neptune.internal.backends.hosted_file_operations._store_response_as_file") @patch("neptune.internal.backends.hosted_file_operations._download_raw_data") @@ -130,8 +139,19 @@ def test_download_file_set_attribute(self, download_raw, store_response_mock): swagger_mock = self._get_swagger_mock() download_id = str(uuid.uuid4()) + pre_download_hook = MagicMock() + download_iter_hook = MagicMock() + post_download_hook = MagicMock() + # when - download_file_set_attribute(swagger_client=swagger_mock, download_id=download_id, destination=None) + download_file_set_attribute( + swagger_client=swagger_mock, + download_id=download_id, + destination=None, + pre_download_hook=pre_download_hook, + download_iter_hook=download_iter_hook, + post_download_hook=post_download_hook, + ) # then download_raw.assert_called_once_with( @@ -139,7 +159,9 @@ def test_download_file_set_attribute(self, download_raw, store_response_mock): url="some_url", headers={"Accept": "application/zip"}, ) - store_response_mock.assert_called_once_with(download_raw.return_value, None) + store_response_mock.assert_called_once_with( + download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook + ) class TestNewUploadFileOperations(HostedFileOperationsHelper, BackendTestMixin):