Skip to content

Commit

Permalink
minor tweaks to manifest logic to reflect location of manifest json
Browse files Browse the repository at this point in the history
  • Loading branch information
GondekNP committed Jan 29, 2024
1 parent c2de8db commit fdc7c5a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 60 deletions.
10 changes: 5 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def analyze_burn(

cloud_static_io_client.upload(
source_local_path=tmp_geojson,
remote_path=f"{affiliation}/{fire_event_name}/boundary.geojson",
remote_path=f"public/{affiliation}/{fire_event_name}/boundary.geojson",
)
cloud_static_io_client.disconnect()

Expand Down Expand Up @@ -335,7 +335,7 @@ def analyze_ecoclass(

cloud_static_io_client.upload(
source_local_path=tmp_geojson_path,
remote_path=f"{affiliation}/{fire_event_name}/ecoclass_dominant_cover.geojson",
remote_path=f"public/{affiliation}/{fire_event_name}/ecoclass_dominant_cover.geojson",
)

logger.log_text(f"Ecoclass GeoJSON uploaded for {fire_event_name}")
Expand Down Expand Up @@ -373,7 +373,7 @@ async def upload_shapefile(
# Upload the zip and a geojson to SFTP
cloud_static_io_client.upload(
source_local_path=tmp_zip,
remote_path=f"{affiliation}/{fire_event_name}/user_uploaded_{file.filename}",
remote_path=f"public/{affiliation}/{fire_event_name}/user_uploaded_{file.filename}",
)

with tempfile.NamedTemporaryFile(suffix=".geojson", delete=False) as tmp:
Expand All @@ -382,7 +382,7 @@ async def upload_shapefile(
f.write(geojson)
cloud_static_io_client.upload(
source_local_path=tmp_geojson,
remote_path=f"{affiliation}/{fire_event_name}/boundary.geojson",
remote_path=f"public/{affiliation}/{fire_event_name}/boundary.geojson",
)

return JSONResponse(status_code=200, content={"geojson": geojson})
Expand All @@ -405,7 +405,7 @@ async def upload_drawn_aoi(
f.write(geojson)
cloud_static_io_client.upload(
source_local_path=tmp_geojson,
remote_path=f"{affiliation}/{fire_event_name}/boundary.geojson",
remote_path=f"public/{affiliation}/{fire_event_name}/boundary.geojson",
)
return JSONResponse(status_code=200, content={"geojson": geojson})

Expand Down
61 changes: 6 additions & 55 deletions src/util/cloud_static_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,6 @@
from google.auth.transport import requests as gcp_requests
from google.auth import impersonated_credentials, exceptions

# TODO [#9]: Convert to agnostic Boto client
# Use the slick smart-open library to handle S3 connections. This maintains the agnostic nature
# of sftp, not tied to any specific cloud provider, but is way more efficient than paramiko/sftp in terms of $$

# def create_s3_client():
# try:
# # Get the OIDC token from your identity provider
# id_token = os.environ.get('OIDC_TOKEN')

# # Create a new STS client
# sts_client = boto3.client('sts')

# # Assume the role with web identity
# assumed_role_object = sts_client.assume_role_with_web_identity(
# RoleArn="arn:aws:iam::account-of-the-iam-role:role/name-of-the-iam-role",
# RoleSessionName="AssumeRoleSession1",
# WebIdentityToken=id_token
# )

# # Extract the credentials
# credentials = assumed_role_object['Credentials']

# # Create a new session with the temporary credentials
# session = boto3.Session(
# aws_access_key_id=credentials['AccessKeyId'],
# aws_secret_access_key=credentials['SecretAccessKey'],
# aws_session_token=credentials['SessionToken'],
# )

# return session.client('s3')

# except (BotoCoreError, NoCredentialsError) as error:
# print(error)
# return None

class CloudStaticIOClient:
def __init__(self, bucket_name, provider):

Expand All @@ -66,11 +31,10 @@ def __init__(self, bucket_name, provider):
log_name = "burn-backend"
self.logger = logging_client.logger(log_name)

boto3.set_stream_logger('')
self.sts_client = boto3.client('sts')

if provider == "s3":
self.prefix = f"s3://{self.bucket_name}/public"
self.prefix = f"s3://{self.bucket_name}"
else:
raise Exception(f"Provider {provider} not supported")

Expand Down Expand Up @@ -121,7 +85,7 @@ def fetch_id_token(self, audience):

def validate_credentials(self):

if not self.role_assumed_credentials or (self.role_assumed_credentials['Expiration'].timestamp() - time.now() < 300):
if not self.role_assumed_credentials or (self.role_assumed_credentials['Expiration'].timestamp() - time.time() < 300):
oidc_token = None
request = gcp_requests.Request()

Expand Down Expand Up @@ -154,6 +118,7 @@ def download(self, remote_path, target_local_path):
Downloads the file from remote s3 server to local.
Also, by default extracts the file to the specified target_local_path
"""
self.validate_credentials()
try:
# Create the target directory if it does not exist
path, _ = os.path.split(target_local_path)
Expand All @@ -179,6 +144,7 @@ def upload(self, source_local_path, remote_path):
"""
Uploads the source files from local to the s3 server.
"""
self.validate_credentials()
try:
print(
f"uploading to {self.bucket_name} [(remote path: {remote_path});(source local path: {source_local_path})]"
Expand Down Expand Up @@ -207,21 +173,6 @@ def listdir_attr(self, remote_path):
for attr in self.connection.listdir_attr(remote_path):
yield attr

# def get_available_cogs(self):
# """Lists all available COGs on the SFTP server"""
# available_cogs = {}
# for top_level_folder in self.connection.listdir():
# if not top_level_folder.endswith(".json"):
# s3_file_path = f"{top_level_folder}/metrics.tif"
# available_cogs[top_level_folder] = s3_file_path

# return available_cogs

# def update_available_cogs(self):
# self.connect()
# self.available_cogs = self.get_available_cogs()
# self.disconnect()

def upload_cogs(
self,
metrics_stack,
Expand All @@ -245,7 +196,7 @@ def upload_cogs(

self.upload(
source_local_path=local_cog_path,
remote_path=f"{affiliation}/{fire_event_name}/{band_name}.tif",
remote_path=f"public/{affiliation}/{fire_event_name}/{band_name}.tif",
)

# Upload the difference between dNBR and RBR
Expand All @@ -261,7 +212,7 @@ def upload_cogs(
pct_change.rio.to_raster(local_cog_path, driver="GTiff")
self.upload(
source_local_path=local_cog_path,
remote_path=f"{affiliation}/{fire_event_name}/pct_change_dnbr_rbr.tif",
remote_path=f"public/{affiliation}/{fire_event_name}/pct_change_dnbr_rbr.tif",
)

def update_manifest(
Expand Down

0 comments on commit fdc7c5a

Please sign in to comment.