[BUG] Autodetect AWS region during deltalake scan #3104

Merged 7 commits on Oct 23, 2024
Changes from 5 commits
14 changes: 14 additions & 0 deletions daft/delta_lake/delta_lake_scan.py

@@ -17,6 +17,7 @@
     ScanTask,
     StorageConfig,
 )
+from daft.io.aws_config import boto3_client_from_s3_config
 from daft.io.object_store_options import io_config_to_storage_options
 from daft.io.scan import PartitionField, ScanOperator
 from daft.logical.schema import Schema

@@ -43,6 +44,19 @@
         deltalake_sdk_io_config = storage_config.config.io_config
         scheme = urlparse(table_uri).scheme
         if scheme == "s3" or scheme == "s3a":
+            # Try to get region from boto3
+            if deltalake_sdk_io_config.s3.region_name is None:
+                try:
+                    client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3)
+                    response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc)
+                except Exception:

Contributor (review comment on the "except Exception:" line): Boto3 superclass

Contributor (review comment):
Also, at least log if we're going to move ahead silently
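
Taking the two comments together, a minimal sketch of what that could look like, assuming botocore's BotoCoreError/ClientError exception superclasses and the standard logging module; names like deltalake_sdk_io_config and table_uri are reused from the hunk above, and this is an illustration, not the merged code:

import logging

from botocore.exceptions import BotoCoreError, ClientError

logger = logging.getLogger(__name__)

try:
    client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3)
    response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc)
except (BotoCoreError, ClientError) as err:
    # Catch boto3's exception superclasses rather than a bare Exception,
    # and log instead of moving ahead silently.
    logger.warning("Could not autodetect S3 bucket region: %s", err)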

+                    pass
+                else:
+                    deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
+                        s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"])
+                    )

Codecov / codecov/patch: added lines daft/delta_lake/delta_lake_scan.py#L49-L53 and #L55 were not covered by tests.

             # Try to get config from the environment
             if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]):
                 try:
                     s3_config_from_env = S3Config.from_env()
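
One caveat with this approach: GetBucketLocation reports a LocationConstraint of None for buckets created in us-east-1, so a lookup like the one above has to normalize that case before storing the region. A standalone sketch (the bucket name is a placeholder, and this is not necessarily how the remaining two commits of this PR handle it):

import boto3

def detect_bucket_region(bucket: str) -> str:
    """Best-effort S3 bucket region lookup via GetBucketLocation."""
    response = boto3.client("s3").get_bucket_location(Bucket=bucket)
    # Buckets in us-east-1 report LocationConstraint as None.
    return response["LocationConstraint"] or "us-east-1"

print(detect_bucket_region("example-bucket"))  # placeholder bucket name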
21 changes: 21 additions & 0 deletions daft/io/aws_config.py

@@ -0,0 +1,21 @@
+from typing import TYPE_CHECKING
+
+from daft.daft import S3Config
+
+if TYPE_CHECKING:
+    import boto3
+
+
+def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client":
+    import boto3

Contributor (review comment on the "import boto3" line): Do we have to amend our requirements to make boto3 a requirement when using deltalake then?

Member (author):
Technically you can use it with only Azure which is why I decided not to list it as an explicit requirement. Thoughts on that?
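
If boto3 stays optional, one conventional pattern is a lazy import that fails with an actionable message; a hypothetical sketch, not what this PR implements:

from daft.daft import S3Config

def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client":
    try:
        import boto3
    except ImportError as err:
        # boto3 stays out of the hard requirements (Azure-only users never
        # need it), but the failure mode is a clear message rather than a
        # bare ModuleNotFoundError.
        raise ImportError(
            "boto3 is required for AWS Glue catalog access and S3 region "
            "autodetection; install it with pip install boto3."
        ) from err
    return boto3.client(service, region_name=s3_config.region_name)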


+    return boto3.client(
+        service,
+        region_name=s3_config.region_name,
+        use_ssl=s3_config.use_ssl,
+        verify=s3_config.verify_ssl,
+        endpoint_url=s3_config.endpoint_url,
+        aws_access_key_id=s3_config.key_id,
+        aws_secret_access_key=s3_config.access_key,
+        aws_session_token=s3_config.session_token,
+    )

Codecov / codecov/patch: added line daft/io/aws_config.py#L6 was not covered by tests.
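
For context, a usage sketch of the new helper; the region value is a placeholder, and whether S3Config accepts this exact constructor keyword should be checked against the daft API:

from daft.daft import S3Config
from daft.io.aws_config import boto3_client_from_s3_config

s3_config = S3Config(region_name="us-west-2")  # placeholder region

# The same helper now builds both clients this PR needs.
s3 = boto3_client_from_s3_config("s3", s3_config)
glue = boto3_client_from_s3_config("glue", s3_config)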
15 changes: 2 additions & 13 deletions daft/io/catalog.py

@@ -5,6 +5,7 @@
 from typing import Optional
 
 from daft.daft import IOConfig
+from daft.io.aws_config import boto3_client_from_s3_config
 
 
 class DataCatalogType(Enum):

@@ -42,20 +43,8 @@ def table_uri(self, io_config: IOConfig) -> str:
"""
if self.catalog == DataCatalogType.GLUE:
# Use boto3 to get the table from AWS Glue Data Catalog.
import boto3
glue = boto3_client_from_s3_config("glue", io_config.s3)

s3_config = io_config.s3

glue = boto3.client(
"glue",
region_name=s3_config.region_name,
use_ssl=s3_config.use_ssl,
verify=s3_config.verify_ssl,
endpoint_url=s3_config.endpoint_url,
aws_access_key_id=s3_config.key_id,
aws_secret_access_key=s3_config.access_key,
aws_session_token=s3_config.session_token,
)
if self.catalog_id is not None:
# Allow cross account access, table.catalog_id should be the target account id
glue_table = glue.get_table(
Expand Down
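
The diff is truncated here, but the continuation presumably reads the table's storage location off the Glue response. A hedged sketch using the standard boto3 Glue API shapes (the region, account id, database, and table names are placeholders):

import boto3

glue = boto3.client("glue", region_name="us-west-2")  # placeholder region

glue_table = glue.get_table(
    CatalogId="123456789012",    # placeholder cross-account target
    DatabaseName="my_database",  # placeholder
    Name="my_table",             # placeholder
)
# A Glue table's S3 URI lives in the storage descriptor.
table_uri = glue_table["Table"]["StorageDescriptor"]["Location"]
print(table_uri)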