Skip to content

Commit

Permalink
updates to data_sync functionality (fail_if_missing flag)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpmcginty committed Jun 21, 2024
1 parent bfbfcb5 commit e484bce
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 8 deletions.
8 changes: 7 additions & 1 deletion src/aibs_informatics_aws_utils/data_sync/file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ def partition(
size_bytes_exceeding_obj_nodes = []

partitioned_nodes: List[Node] = []
logger.info(
f"Partitioning nodes with size_bytes_limit={size_bytes_limit} "
f"and object_count_limit={object_count_limit}"
)

while unchecked_nodes:
unchecked_node = unchecked_nodes.pop()
if (size_bytes_limit and unchecked_node.size_bytes > size_bytes_limit) or (
Expand All @@ -219,6 +224,7 @@ def partition(
raise ValueError(msg)
logger.warning(msg)
partitioned_nodes.extend(size_bytes_exceeding_obj_nodes)
logger.info(f"Partitioned {len(partitioned_nodes)} nodes.")
return partitioned_nodes

@classmethod
Expand Down Expand Up @@ -326,7 +332,7 @@ def from_path(cls, path: str, **kwargs) -> S3FileSystem:
return s3_root


def get_file_system(path: Optional[Union[str, Path]]) -> BaseFileSystem:
def get_file_system(path: Union[str, Path]) -> BaseFileSystem:
if isinstance(path, str) and S3URI.is_valid(path):
return S3FileSystem.from_path(path)
elif isinstance(path, str) and EFSPath.is_valid(path):
Expand Down
38 changes: 37 additions & 1 deletion src/aibs_informatics_aws_utils/data_sync/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from aibs_informatics_core.utils.os_operations import find_all_paths

from aibs_informatics_aws_utils.efs import get_local_path
from aibs_informatics_aws_utils.s3 import Config, TransferConfig, delete_s3_path, sync_paths
from aibs_informatics_aws_utils.s3 import (
Config,
TransferConfig,
delete_s3_path,
is_folder,
is_object,
sync_paths,
)

logger = get_logger(__name__)

Expand All @@ -45,6 +52,11 @@ def botocore_config(self) -> Config:

def sync_local_to_s3(self, source_path: LocalPath, destination_path: S3URI):
source_path = self.sanitize_local_path(source_path)
if not source_path.exists():
if self.config.fail_if_missing:
raise FileNotFoundError(f"Local path {source_path} does not exist")
self.logger.warning(f"Local path {source_path} does not exist")
return
if source_path.is_dir():
self.logger.info("local source path is folder. Adding suffix to destination path")
destination_path = S3URI.build(
Expand All @@ -68,6 +80,13 @@ def sync_s3_to_local(self, source_path: S3URI, destination_path: LocalPath):
self.logger.info(f"Downloading s3 content from {source_path} -> {destination_path}")
start_time = datetime.now(tz=timezone.utc)

if not is_object(source_path) and not is_folder(source_path):
message = f"S3 path {source_path} does not exist as object or folder"
if self.config.fail_if_missing:
raise FileNotFoundError(message)
self.logger.warning(message)
return

_sync_paths = sync_paths

if self.config.require_lock:
Expand Down Expand Up @@ -113,6 +132,13 @@ def sync_local_to_local(self, source_path: LocalPath, destination_path: LocalPat
destination_path = self.sanitize_local_path(destination_path)
self.logger.info(f"Copying local content from {source_path} -> {destination_path}")
start_time = datetime.now(tz=timezone.utc)

if not source_path.exists():
if self.config.fail_if_missing:
raise FileNotFoundError(f"Local path {source_path} does not exist")
self.logger.warning(f"Local path {source_path} does not exist")
return

if self.config.retain_source_data:
copy_path(source_path=source_path, destination_path=destination_path, exists_ok=True)
else:
Expand All @@ -127,6 +153,14 @@ def sync_s3_to_s3(
source_path_prefix: Optional[S3KeyPrefix] = None,
):
self.logger.info(f"Syncing s3 content from {source_path} -> {destination_path}")

if not is_object(source_path) and not is_folder(source_path):
message = f"S3 path {source_path} does not exist as object or folder"
if self.config.fail_if_missing:
raise FileNotFoundError(message)
self.logger.warning(message)
return

sync_paths(
source_path=source_path,
destination_path=destination_path,
Expand Down Expand Up @@ -200,6 +234,7 @@ def sync_data(
require_lock: bool = False,
force: bool = False,
size_only: bool = False,
fail_if_missing: bool = True,
):
request = DataSyncRequest(
source_path=source_path,
Expand All @@ -210,6 +245,7 @@ def sync_data(
require_lock=require_lock,
force=force,
size_only=size_only,
fail_if_missing=fail_if_missing,
)
return DataSyncOperations.sync_request(request=request)

Expand Down
72 changes: 66 additions & 6 deletions test/aibs_informatics_aws_utils/data_sync/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path
from test.aibs_informatics_aws_utils.base import AwsBaseTest
from typing import Union
from typing import Optional, Union

import moto
from aibs_informatics_core.models.aws.s3 import S3URI
Expand All @@ -25,20 +25,22 @@ def setUpLocalFS(self) -> Path:
fs = self.tmp_path()
return fs

def setUpBucket(self, bucket_name: str = None) -> str:
def setUpBucket(self, bucket_name: Optional[str] = None) -> str:
bucket_name = bucket_name or self.DEFAULT_BUCKET_NAME
self.s3_client.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": self.DEFAULT_REGION},
CreateBucketConfiguration={"LocationConstraint": self.DEFAULT_REGION}, # type: ignore
)
return bucket_name

def put_object(self, key: str, content: str, bucket_name: str = None, **kwargs) -> S3URI:
def put_object(
self, key: str, content: str, bucket_name: Optional[str] = None, **kwargs
) -> S3URI:
bucket_name = bucket_name or self.DEFAULT_BUCKET_NAME
self.s3_client.put_object(Bucket=bucket_name, Key=key, Body=content, **kwargs)
return self.get_s3_path(key=key, bucket_name=bucket_name)

def get_object(self, key: str, bucket_name: str = None) -> str:
def get_object(self, key: str, bucket_name: Optional[str] = None) -> str:
bucket_name = bucket_name or self.DEFAULT_BUCKET_NAME
response = self.s3_client.get_object(Bucket=bucket_name, Key=key)
return response["Body"].read().decode()
Expand All @@ -59,7 +61,7 @@ def s3_client(self):
def s3_resource(self):
return get_s3_resource(region=self.DEFAULT_REGION)

def get_s3_path(self, key: str, bucket_name: str = None) -> S3URI:
def get_s3_path(self, key: str, bucket_name: Optional[str] = None) -> S3URI:
bucket_name = bucket_name or self.DEFAULT_BUCKET_NAME
return S3URI.build(bucket_name=bucket_name, key=key)

Expand Down Expand Up @@ -102,6 +104,20 @@ def test__sync_data__s3_to_s3__file__succeeds__source_deleted(self):
assert self.get_object(destination_path.key) == "hello"
assert not is_object(source_path)

def test__sync_data__s3_to_s3__file__does_not_exist(self):
self.setUpBucket()
source_path = self.get_s3_path("source")
destination_path = self.get_s3_path("destination")
with self.assertRaises(FileNotFoundError):
sync_data(
source_path=source_path,
destination_path=destination_path,
)
sync_data(
source_path=source_path, destination_path=destination_path, fail_if_missing=False
)
assert not is_object(destination_path)

def test__sync_data__local_to_local__folder__succeeds(self):
fs = self.setUpLocalFS()
source_path = fs / "source"
Expand Down Expand Up @@ -153,6 +169,20 @@ def test__sync_data__local_to_local__file__source_deleted(self):
assert destination_path.read_text() == "hello"
assert not source_path.exists()

def test__sync_data__local_to_local__file__does_not_exist(self):
fs = self.setUpLocalFS()
source_path = fs / "source"
destination_path = fs / "destination"
with self.assertRaises(FileNotFoundError):
sync_data(
source_path=source_path,
destination_path=destination_path,
)
sync_data(
source_path=source_path, destination_path=destination_path, fail_if_missing=False
)
assert not destination_path.exists()

def test__sync_data__s3_to_local__folder__succeeds(self):
fs = self.setUpLocalFS()
self.setUpBucket()
Expand Down Expand Up @@ -223,6 +253,21 @@ def test__sync_data__s3_to_local__file__source_not_deleted_despite_flag(self):
)
self.assertPathsEqual(source_path, destination_path, 1)

def test__sync_data__s3_to_local__file__does_not_exist(self):
fs = self.setUpLocalFS()
self.setUpBucket()
source_path = self.get_s3_path("source")
destination_path = fs / "destination"
with self.assertRaises(FileNotFoundError):
sync_data(
source_path=source_path,
destination_path=destination_path,
)
sync_data(
source_path=source_path, destination_path=destination_path, fail_if_missing=False
)
assert not destination_path.exists()

def test__sync_data__local_to_s3__folder__succeeds(self):
fs = self.setUpLocalFS()
self.setUpBucket()
Expand Down Expand Up @@ -264,6 +309,21 @@ def test__sync_data__local_to_s3__file__source_deleted(self):
)
assert not source_path.exists()

def test__sync_data__local_to_s3__file__does_not_exist(self):
fs = self.setUpLocalFS()
self.setUpBucket()
source_path = fs / "source"
destination_path = self.get_s3_path("destination")
with self.assertRaises(FileNotFoundError):
sync_data(
source_path=source_path,
destination_path=destination_path,
)
sync_data(
source_path=source_path, destination_path=destination_path, fail_if_missing=False
)
assert not is_object(destination_path)

def assertPathsEqual(
self, src_path: Union[Path, S3URI], dst_path: Union[Path, S3URI], expected_num_files: int
):
Expand Down

0 comments on commit e484bce

Please sign in to comment.