Add Beaker.dataset.stream_file() method
epwalsh committed Apr 8, 2022
1 parent 6e3e8fc commit af423e6
Showing 5 changed files with 133 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `Beaker.dataset.commit()`.
- Added `Beaker.dataset.ls()`.
- Added `Beaker.dataset.stream_file()`.

### Changed

12 changes: 10 additions & 2 deletions beaker/data_model.py
@@ -654,10 +654,18 @@ def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]:

class FileInfo(BaseModel):
path: str
size: int
digest: str
updated: datetime
url: str

size: Optional[int] = None
"""
The size of the file, if known.
"""

url: Optional[str] = None
"""
A URL that can be used to directly download the file.
"""


class DatasetManifest(BaseModel):
126 changes: 108 additions & 18 deletions beaker/services/dataset.py
@@ -1,4 +1,5 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Deque, Dict, Generator, List, Optional, Tuple, Union

@@ -24,6 +25,7 @@ class DatasetClient(ServiceClient):
HEADER_UPLOAD_LENGTH = "Upload-Length"
HEADER_UPLOAD_OFFSET = "Upload-Offset"
HEADER_DIGEST = "Digest"
HEADER_LAST_MODIFIED = "Last-Modified"

SHA256 = "SHA256"

@@ -183,13 +185,19 @@ def fetch(
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
download_futures = []
for file_info in self._iter_files(dataset.storage):
assert file_info.size is not None
if total_bytes_to_download is None:
progress.update(bytes_task, total=total_downloaded + file_info.size + 1)
target_path = target / Path(file_info.path)
if not force and target_path.exists():
raise FileExistsError(file_info.path)
future = executor.submit(
self._download_file, file_info, target_path, progress, bytes_task
self._download_file,
dataset.storage,
file_info,
target_path,
progress,
bytes_task,
)
download_futures.append(future)

@@ -199,6 +207,49 @@ def fetch(
if total_bytes_to_download is None:
progress.update(bytes_task, total=total_downloaded, completed=total_downloaded)

def stream_file(
self,
dataset: Union[str, Dataset],
file_name: str,
offset: int = 0,
length: int = -1,
max_retries: int = 5,
) -> Generator[bytes, None, None]:
"""
Stream download the contents of a single file from a dataset.
:param dataset: The dataset ID, full name, or object.
:param file_name: The path of the file within the dataset.
:param offset: Offset to start from, in bytes.
:param length: Number of bytes to read.
:param max_retries: Number of times to restart the download when HTTP errors occur.
Errors can be expected for very large files.
:raises DatasetNotFound: If the dataset can't be found.
:raises FileNotFoundError: If the file doesn't exist in the dataset.
:raises HTTPError: Any other HTTP exception that can occur.
"""
if not isinstance(dataset, Dataset) or dataset.storage is None:
dataset = self.get(dataset.id if isinstance(dataset, Dataset) else dataset)
assert dataset.storage is not None
response = self.request(
f"datasets/{dataset.storage.id}/files/{file_name}",
method="HEAD",
token=dataset.storage.token,
base_url=dataset.storage.address,
exceptions_for_status={404: FileNotFoundError(file_name)},
)
file_info = FileInfo(
path=file_name,
digest=response.headers[self.HEADER_DIGEST],
updated=datetime.strptime(
response.headers[self.HEADER_LAST_MODIFIED], "%a, %d %b %Y %H:%M:%S %Z"
),
)
yield from self._stream_file(
dataset.storage, file_info, offset=offset, length=length, max_retries=max_retries
)
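# Illustrative usage, not part of this commit: assuming `beaker` is a configured
# Beaker client and "my-dataset" contains a file named "data.csv" (both
# hypothetical), the new method can stream the file to disk chunk by chunk:
#
#     with open("data.csv", "wb") as f:
#         for chunk in beaker.dataset.stream_file("my-dataset", "data.csv"):
#             f.write(chunk)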

def get(self, dataset: str) -> Dataset:
"""
Get info about a dataset.
@@ -456,28 +507,67 @@ def _iter_files(self, storage: DatasetStorage) -> Generator[FileInfo, None, None]:
if not cursor:
last_request = True

def _stream_file(
self,
storage: DatasetStorage,
file_info: FileInfo,
chunk_size: int = 1,
offset: int = 0,
length: int = -1,
max_retries: int = 5,
) -> Generator[bytes, None, None]:
def stream_file(offset: int, length: int) -> Generator[bytes, None, None]:
headers = {}
if offset > 0 and length > 0:
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
elif offset > 0:
headers["Range"] = f"bytes={offset}-"
response = self.request(
f"datasets/{storage.id}/files/{file_info.path}",
method="GET",
stream=True,
headers=headers,
token=storage.token,
base_url=storage.address,
exceptions_for_status={404: FileNotFoundError(file_info.path)},
)
for chunk in response.iter_content(chunk_size=chunk_size):
yield chunk

retries = 0
while True:
try:
for chunk in stream_file(offset, length):
offset += len(chunk)
yield chunk
break
except HTTPError:
if retries >= max_retries:
raise
retries += 1
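# Editorial note, not part of this commit: because `offset` advances as each chunk
# is yielded, a retry after an HTTPError re-issues the request with a Range header
# starting at the current offset (e.g. "Range: bytes=1048576-"), so the download
# resumes near where the failed request stopped instead of restarting from zero.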

def _download_file(
self, file_info: FileInfo, target_path: Path, progress: Progress, task_id: TaskID
self,
storage: DatasetStorage,
file_info: FileInfo,
target_path: Path,
progress: Progress,
task_id: TaskID,
) -> int:
import tempfile

total_bytes = 0
target_dir = target_path.parent
target_dir.mkdir(exist_ok=True, parents=True)
with self._session_with_backoff() as session:
response = session.get(file_info.url, stream=True)
response.raise_for_status()
tmp_target = tempfile.NamedTemporaryFile(
"w+b", dir=target_dir, delete=False, suffix=".tmp"
)
try:
for chunk in response.iter_content(chunk_size=1024):
total_bytes += len(chunk)
tmp_target.write(chunk)
progress.update(task_id, advance=len(chunk))
os.replace(tmp_target.name, target_path)
finally:
tmp_target.close()
if os.path.exists(tmp_target.name):
os.remove(tmp_target.name)
tmp_target = tempfile.NamedTemporaryFile("w+b", dir=target_dir, delete=False, suffix=".tmp")
try:
for chunk in self._stream_file(storage, file_info):
total_bytes += len(chunk)
tmp_target.write(chunk)
progress.update(task_id, advance=len(chunk))
os.replace(tmp_target.name, target_path)
finally:
tmp_target.close()
if os.path.exists(tmp_target.name):
os.remove(tmp_target.name)
return total_bytes
2 changes: 1 addition & 1 deletion beaker/services/service_client.py
@@ -51,7 +51,7 @@ def request(
method: str = "GET",
query: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
exceptions_for_status: Optional[Dict[int, BeakerError]] = None,
exceptions_for_status: Optional[Dict[int, Exception]] = None,
headers: Optional[Dict[str, str]] = None,
token: Optional[str] = None,
base_url: Optional[str] = None,
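Editorial note: the widened `exceptions_for_status` annotation (`Dict[int, BeakerError]` to `Dict[int, Exception]`) accommodates calls like the one in `stream_file()` above, which maps a 404 response to the built-in `FileNotFoundError` rather than a `BeakerError` subclass.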
13 changes: 13 additions & 0 deletions tests/dataset_test.py
@@ -30,6 +30,19 @@ def test_dataset_write_error(self, client: Beaker, dataset_name: str):
with pytest.raises(DatasetWriteError):
client.dataset.sync(dataset, self.file_b.name)

def test_stream_file(self, client: Beaker, dataset_name: str):
dataset = client.dataset.create(dataset_name, self.file_a.name, commit=True)

# Stream the whole thing at once.
contents = b"".join(list(client.dataset.stream_file(dataset, Path(self.file_a.name).name)))
assert contents == self.file_a_contents

# Stream just part of the file.
contents = b"".join(
list(client.dataset.stream_file(dataset, Path(self.file_a.name).name, offset=5))
)
assert contents == self.file_a_contents[5:]


class TestLargeFileDataset:
def setup_method(self):
