Skip to content

Commit

Permalink
download hook (#1571)
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksanderWWW authored Nov 21, 2023
1 parent 4eae0ab commit 4dda93b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## [UNRELEASED] 1.8.5
## [UNRELEASED] neptune 1.8.5

### Changes
- Enabled hooks for internal downloading functions used by the hosted backend ([#1571](https://github.com/neptune-ai/neptune-client/pull/1571))
- Added timestamp of operation put to disk queue ([#1569](https://github.com/neptune-ai/neptune-client/pull/1569))


Expand Down
24 changes: 21 additions & 3 deletions src/neptune/internal/backends/hosted_file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from io import BytesIO
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -420,6 +421,9 @@ def download_file_attribute(
container_id: str,
attribute: str,
destination: Optional[str] = None,
pre_download_hook: Callable[[int], None] = lambda x: None,
download_iter_hook: Callable[[int], None] = lambda x: None,
post_download_hook: Callable[[], None] = lambda: None,
):
url = build_operation_url(
swagger_client.swagger_spec.api_url,
Expand All @@ -431,13 +435,16 @@ def download_file_attribute(
headers={"Accept": "application/octet-stream"},
query_params={"experimentId": container_id, "attribute": attribute},
)
_store_response_as_file(response, destination)
_store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook)


def download_file_set_attribute(
swagger_client: SwaggerClientWrapper,
download_id: str,
destination: Optional[str] = None,
pre_download_hook: Callable[[int], None] = lambda x: None,
download_iter_hook: Callable[[int], None] = lambda x: None,
post_download_hook: Callable[[], None] = lambda: None,
):
download_url: Optional[str] = _get_download_url(swagger_client, download_id)
next_sleep = 0.5
Expand All @@ -451,7 +458,7 @@ def download_file_set_attribute(
url=download_url,
headers={"Accept": "application/zip"},
)
_store_response_as_file(response, destination)
_store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook)


def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str):
Expand All @@ -460,18 +467,29 @@ def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str):
return download_request.downloadUrl


def _store_response_as_file(response: Response, destination: Optional[str] = None):
def _store_response_as_file(
response: Response,
destination: Optional[str] = None,
pre_download_hook: Callable[[int], None] = lambda x: None,
download_iter_hook: Callable[[int], None] = lambda x: None,
post_download_hook: Callable[[], None] = lambda: None,
) -> None:
if destination is None:
target_file = _get_content_disposition_filename(response)
elif os.path.isdir(destination):
target_file = os.path.join(destination, _get_content_disposition_filename(response))
else:
target_file = destination

total_size = int(response.headers.get("content-length", 0))
pre_download_hook(total_size)
with response:
with open(target_file, "wb") as f:
for chunk in response.iter_content(chunk_size=1024 * 1024):
if chunk:
f.write(chunk)
download_iter_hook(len(chunk) if chunk else 0)
post_download_hook()


def _get_content_disposition_filename(response: Response) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,19 @@ def test_download_file_attribute(self, download_raw, store_response_mock):
swagger_mock = self._get_swagger_mock()
exp_uuid = str(uuid.uuid4())

pre_download_hook = MagicMock()
download_iter_hook = MagicMock()
post_download_hook = MagicMock()

# when
download_file_attribute(
swagger_client=swagger_mock,
container_id=exp_uuid,
attribute="some/attribute",
destination=None,
pre_download_hook=pre_download_hook,
download_iter_hook=download_iter_hook,
post_download_hook=post_download_hook,
)

# then
Expand All @@ -117,7 +124,9 @@ def test_download_file_attribute(self, download_raw, store_response_mock):
headers={"Accept": "application/octet-stream"},
query_params={"experimentId": str(exp_uuid), "attribute": "some/attribute"},
)
store_response_mock.assert_called_once_with(download_raw.return_value, None)
store_response_mock.assert_called_once_with(
download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook
)

@patch("neptune.internal.backends.hosted_file_operations._store_response_as_file")
@patch("neptune.internal.backends.hosted_file_operations._download_raw_data")
Expand All @@ -130,16 +139,29 @@ def test_download_file_set_attribute(self, download_raw, store_response_mock):
swagger_mock = self._get_swagger_mock()
download_id = str(uuid.uuid4())

pre_download_hook = MagicMock()
download_iter_hook = MagicMock()
post_download_hook = MagicMock()

# when
download_file_set_attribute(swagger_client=swagger_mock, download_id=download_id, destination=None)
download_file_set_attribute(
swagger_client=swagger_mock,
download_id=download_id,
destination=None,
pre_download_hook=pre_download_hook,
download_iter_hook=download_iter_hook,
post_download_hook=post_download_hook,
)

# then
download_raw.assert_called_once_with(
http_client=swagger_mock.swagger_spec.http_client,
url="some_url",
headers={"Accept": "application/zip"},
)
store_response_mock.assert_called_once_with(download_raw.return_value, None)
store_response_mock.assert_called_once_with(
download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook
)


class TestNewUploadFileOperations(HostedFileOperationsHelper, BackendTestMixin):
Expand Down

0 comments on commit 4dda93b

Please sign in to comment.