Skip to content

Commit

Permalink
some refactor likely needed - don't think my abstractions are ideal f…
Browse files Browse the repository at this point in the history
…or what we're actually _doing_ regarding authenticating with aws, but it works locally w/ the impersonated SA
  • Loading branch information
GondekNP committed Jan 29, 2024
1 parent 1de2975 commit c2de8db
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 40 deletions.
21 changes: 1 addition & 20 deletions .deployment/tofu/modules/static_io/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -271,30 +271,11 @@ data "aws_iam_policy_document" "oidc_assume_role_policy" {

# Defines what actions can be done once the role is assumed.
data "aws_iam_policy_document" "session_policy" {
statement {
sid = "AllowListingOfUserFolder"
effect = "Allow"
actions = [
"s3:ListBucket",
]
resources = [
"arn:aws:s3:::burn-severity-backend",
]
condition {
test = "StringLike"
variable = "s3:prefix"
values = [
"/public/*",
"/public",
"/"
]
}
}

statement {
sid = "HomeDirObjectAccess"
effect = "Allow"
actions = [
"s3:ListBucket",
"s3:PutObject",
"s3:GetObject",
"s3:DeleteObject",
Expand Down
56 changes: 36 additions & 20 deletions src/util/cloud_static_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import smart_open
import time
import os
import json
import datetime
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(self, bucket_name, provider):
raise Exception(f"Provider {provider} not supported")

self.iam_credentials = None
self.role_assumed_credentials = None
self.s3_session = None
self.validate_credentials()

self.logger.log_text(f"Initialized CloudStaticIOClient for {self.bucket_name} with provider {provider}")
Expand Down Expand Up @@ -117,25 +120,34 @@ def fetch_id_token(self, audience):
return response.json()["token"]

def validate_credentials(self):
oidc_token = None
request = gcp_requests.Request()

if self.env == 'LOCAL':
if not self.iam_credentials or self.iam_credentials.expired:
self.impersonate_service_account()
self.iam_credentials.refresh(request)

oidc_token = self.fetch_id_token(audience="sts.amazonaws.com")
if not oidc_token:
raise ValueError("Failed to retrieve OIDC token")

sts_response = self.sts_client.assume_role_with_web_identity(
RoleArn=self.role_arn,
RoleSessionName=self.role_session_name,
WebIdentityToken=oidc_token
)

return sts_response['Credentials']
if not self.role_assumed_credentials or (self.role_assumed_credentials['Expiration'].timestamp() - time.now() < 300):
oidc_token = None
request = gcp_requests.Request()

if self.env == 'LOCAL':
if not self.iam_credentials or self.iam_credentials.expired:
self.impersonate_service_account()
self.iam_credentials.refresh(request)

oidc_token = self.fetch_id_token(audience="sts.amazonaws.com")
if not oidc_token:
raise ValueError("Failed to retrieve OIDC token")

sts_response = self.sts_client.assume_role_with_web_identity(
RoleArn=self.role_arn,
RoleSessionName=self.role_session_name,
WebIdentityToken=oidc_token
)

self.role_assumed_credentials = sts_response['Credentials']

self.boto_session = boto3.Session(
aws_access_key_id=self.role_assumed_credentials['AccessKeyId'],
aws_secret_access_key=self.role_assumed_credentials['SecretAccessKey'],
aws_session_token=self.role_assumed_credentials['SessionToken'],
region_name='us-east-2'
)

def download(self, remote_path, target_local_path):
"""
Expand All @@ -153,7 +165,9 @@ def download(self, remote_path, target_local_path):

# Download from remote s3 server to local
with smart_open.open(
f"{self.prefix}/{remote_path}"
f"{self.prefix}/{remote_path}",
"rb",
transport_params={"client": self.boto_session.client('s3')},
) as remote_file:
with open(target_local_path, "wb") as local_file:
local_file.write(remote_file.read())
Expand All @@ -173,7 +187,9 @@ def upload(self, source_local_path, remote_path):
# Upload file from local to S3
with open(source_local_path, "rb") as local_file:
with smart_open.open(
f"{self.prefix}/{remote_path}", "wb"
f"{self.prefix}/{remote_path}",
"wb",
transport_params={"client": self.boto_session.client('s3')},
) as remote_file:
remote_file.write(local_file.read())
print("upload completed")
Expand Down

0 comments on commit c2de8db

Please sign in to comment.