Skip to content

Commit

Permalink
refactor cloud creation and add support for gcp
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Apr 22, 2024
1 parent 4ce6081 commit c93f6c4
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions fog_x/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@ def convert_to_h264(input_file, output_file):
]
subprocess.run(command)

def create_cloud_bucket_if_not_exist(provider, bucket_name, dir_name):
logger.info(f"Creating bucket '{bucket_name}' in cloud provider '{provider}' with folder '{dir_name}'...")
if provider == "s3":
import boto3
s3_client = boto3.client('s3')
# s3_client.create_bucket(Bucket=bucket_name)
s3_client.put_object(Bucket=bucket_name, Key=f"{dir_name}/")
logger.info(f"Bucket '{bucket_name}' created in AWS S3.")
elif provider == "gs":
from google.cloud import storage
"""Create a folder in a Google Cloud Storage bucket if it does not exist."""
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)

# Ensure the folder name ends with a '/'
if not dir_name.endswith('/'):
dir_name += '/'

# Check if folder exists by trying to list objects with the folder prefix
blobs = storage_client.list_blobs(bucket_name, prefix=dir_name, delimiter='/')
exists = any(blob.name == dir_name for blob in blobs)

if not exists:
# Create an empty blob to simulate a folder
blob = bucket.blob(dir_name)
blob.upload_from_string('')
print(f"Folder '{dir_name}' created.")
else:
print(f"Folder '{dir_name}' already exists.")
else:
raise ValueError(f"Unsupported cloud provider '{provider}'.")

class Dataset:
"""
Create or load from a new dataset.
Expand Down Expand Up @@ -68,7 +100,7 @@ def __init__(
* is replace_existing actually used anywhere?
"""
self.name = name
path = os.path.expanduser(path)
path = os.path.expanduser(path).strip("/")
self.path = path
if path is None:
raise ValueError("Path is required")
Expand All @@ -90,21 +122,9 @@ def __init__(
except:
logger.info(f"Path does not exist. ({path}/{name})")
cloud_provider = path[:2]
if cloud_provider == "s3":
logger.info(f"Creating {cloud_provider} bucket...")
import boto3
s3_client = boto3.client('s3')
bucket_name = path[5:]
# s3_client.create_bucket(Bucket=bucket_name)
s3_client.put_object(Bucket=bucket_name, Key=f"{name}/")
logger.info(f"Bucket '{bucket_name}' created in AWS S3.")
# Reinitialize step_data_connector
step_data_connector = LazyFrameConnector(f"{path}/{name}")
elif cloud_provider == "gs":
logger.info(f"Creating {cloud_provider} bucket...")
pass
else:
logger.info(f"Unsupported cloud_provider {cloud_provider}.")
bucket_name = path[5:]
create_cloud_bucket_if_not_exist(cloud_provider, bucket_name, f"{name}/")
step_data_connector = LazyFrameConnector(f"{path}/{name}")
self.db_manager = DatabaseManager(episode_info_connector, step_data_connector)
self.db_manager.initialize_dataset(self.name, features)

Expand Down

0 comments on commit c93f6c4

Please sign in to comment.