From c93f6c4f0da83fce1f5f2bc30076d797966ec5d1 Mon Sep 17 00:00:00 2001 From: Eric Chen Date: Mon, 22 Apr 2024 07:27:34 +0000 Subject: [PATCH] refactor cloud creation and add support for gcp --- fog_x/dataset.py | 52 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/fog_x/dataset.py b/fog_x/dataset.py index 07f25fa..eb99de0 100644 --- a/fog_x/dataset.py +++ b/fog_x/dataset.py @@ -33,6 +33,38 @@ def convert_to_h264(input_file, output_file): ] subprocess.run(command) +def create_cloud_bucket_if_not_exist(provider, bucket_name, dir_name): + logger.info(f"Creating bucket '{bucket_name}' in cloud provider '{provider}' with folder '{dir_name}'...") + if provider == "s3": + import boto3 + s3_client = boto3.client('s3') + # s3_client.create_bucket(Bucket=bucket_name) + s3_client.put_object(Bucket=bucket_name, Key=f"{dir_name}/") + logger.info(f"Bucket '{bucket_name}' created in AWS S3.") + elif provider == "gs": + from google.cloud import storage + """Create a folder in a Google Cloud Storage bucket if it does not exist.""" + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + + # Ensure the folder name ends with a '/' + if not dir_name.endswith('/'): + dir_name += '/' + + # Check if folder exists by trying to list objects with the folder prefix + blobs = storage_client.list_blobs(bucket_name, prefix=dir_name, delimiter='/') + exists = any(blob.name == dir_name for blob in blobs) + + if not exists: + # Create an empty blob to simulate a folder + blob = bucket.blob(dir_name) + blob.upload_from_string('') + print(f"Folder '{dir_name}' created.") + else: + print(f"Folder '{dir_name}' already exists.") + else: + raise ValueError(f"Unsupported cloud provider '{provider}'.") + class Dataset: """ Create or load from a new dataset. @@ -68,7 +100,7 @@ def __init__( * is replace_existing actually used anywhere? """ self.name = name - path = os.path.expanduser(path) + path = os.path.expanduser(path).strip("/") self.path = path if path is None: raise ValueError("Path is required") @@ -90,21 +122,9 @@ def __init__( except: logger.info(f"Path does not exist. ({path}/{name})") cloud_provider = path[:2] - if cloud_provider == "s3": - logger.info(f"Creating {cloud_provider} bucket...") - import boto3 - s3_client = boto3.client('s3') - bucket_name = path[5:] - # s3_client.create_bucket(Bucket=bucket_name) - s3_client.put_object(Bucket=bucket_name, Key=f"{name}/") - logger.info(f"Bucket '{bucket_name}' created in AWS S3.") - # Reinitialize step_data_connector - step_data_connector = LazyFrameConnector(f"{path}/{name}") - elif cloud_provider == "gs": - logger.info(f"Creating {cloud_provider} bucket...") - pass - else: - logger.info(f"Unsupported cloud_provider {cloud_provider}.") + bucket_name = path[5:] + create_cloud_bucket_if_not_exist(cloud_provider, bucket_name, f"{name}/") + step_data_connector = LazyFrameConnector(f"{path}/{name}") self.db_manager = DatabaseManager(episode_info_connector, step_data_connector) self.db_manager.initialize_dataset(self.name, features)