From 3bf3569c6a29cd95b8c8141180e61672299ce3f3 Mon Sep 17 00:00:00 2001 From: Karan Jariwala Date: Mon, 5 Sep 2022 15:06:15 -0700 Subject: [PATCH] Improved comments and improved test code (#1502) --- composer/datasets/streaming/download.py | 5 +---- tests/datasets/test_streaming.py | 10 +++++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/composer/datasets/streaming/download.py b/composer/datasets/streaming/download.py index 1652759a72..37241a545a 100644 --- a/composer/datasets/streaming/download.py +++ b/composer/datasets/streaming/download.py @@ -37,7 +37,7 @@ def get_object_store(remote: str) -> ObjectStore: elif remote.startswith('sftp://'): return _get_sftp_object_store(remote) else: - raise ValueError('unsupported upload scheme') + raise ValueError('unsupported download scheme') def _get_s3_object_store(remote: str) -> S3ObjectStore: @@ -62,9 +62,6 @@ def _get_sftp_object_store(remote: str) -> SFTPObjectStore: return object_store -__all__ = ['download_or_wait'] - - def download_from_local(remote: str, local: str) -> None: """Download a file from remote to local. diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index b5eba85075..a612f9d4e8 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -19,8 +19,8 @@ @pytest.fixture def remote_local(tmp_path: pathlib.Path) -> Tuple[str, str]: - remote = tmp_path / 'remote' - local = tmp_path / 'local' + remote = tmp_path.joinpath('remote') + local = tmp_path.joinpath('local') remote.mkdir() local.mkdir() return str(remote), str(local) @@ -28,9 +28,9 @@ def remote_local(tmp_path: pathlib.Path) -> Tuple[str, str]: @pytest.fixture def compressed_remote_local(tmp_path: pathlib.Path) -> Tuple[str, str, str]: - compressed = tmp_path / 'compressed' - remote = tmp_path / 'remote' - local = tmp_path / 'local' + compressed = tmp_path.joinpath('compressed') + remote = tmp_path.joinpath('remote') + local = tmp_path.joinpath('local') list(x.mkdir() for x in [compressed, remote, local]) return tuple(str(x) for x in [compressed, remote, local])