diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py
index ffac89ea98d53..8a1f3910746c9 100644
--- a/posthog/temporal/batch_exports/s3_batch_export.py
+++ b/posthog/temporal/batch_exports/s3_batch_export.py
@@ -51,9 +51,7 @@
     BatchExportTemporaryFile,
     WriterFormat,
 )
-from posthog.temporal.batch_exports.utils import (
-    set_status_to_running_task,
-)
+from posthog.temporal.batch_exports.utils import set_status_to_running_task
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
 from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -73,6 +71,8 @@
     "InvalidS3Key",
     # All consumers failed with non-retryable errors.
     "RecordBatchConsumerNonRetryableExceptionGroup",
+    # Invalid S3 endpoint URL.
+    "InvalidS3EndpointError",
 ]
 
 FILE_FORMAT_EXTENSIONS = {
@@ -166,6 +166,13 @@ def __init__(self):
         super().__init__("Endpoint URL cannot be empty.")
 
 
+class InvalidS3EndpointError(Exception):
+    """Exception raised when an S3 endpoint is invalid."""
+
+    def __init__(self, message: str = "Endpoint URL is invalid."):
+        super().__init__(message)
+
+
 Part = dict[str, str | int]
 
 
@@ -240,14 +247,19 @@ def is_upload_in_progress(self) -> bool:
     async def s3_client(self):
         """Asynchronously yield an S3 client."""
 
-        async with self._session.client(
-            "s3",
-            region_name=self.region_name,
-            aws_access_key_id=self.aws_access_key_id,
-            aws_secret_access_key=self.aws_secret_access_key,
-            endpoint_url=self.endpoint_url,
-        ) as client:
-            yield client
+        try:
+            async with self._session.client(
+                "s3",
+                region_name=self.region_name,
+                aws_access_key_id=self.aws_access_key_id,
+                aws_secret_access_key=self.aws_secret_access_key,
+                endpoint_url=self.endpoint_url,
+            ) as client:
+                yield client
+        except ValueError as err:
+            if "Invalid endpoint" in str(err):
+                raise InvalidS3EndpointError(str(err)) from err
+            raise
 
     async def start(self) -> str:
         """Start this S3MultiPartUpload."""
diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
index dc0de17e53d58..76a3c20599518 100644
--- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py
@@ -29,6 +29,7 @@
 from posthog.temporal.batch_exports.s3_batch_export import (
     FILE_FORMAT_EXTENSIONS,
     IntermittentUploadPartTimeoutError,
+    InvalidS3EndpointError,
     S3BatchExportInputs,
     S3BatchExportWorkflow,
     S3HeartbeatDetails,
@@ -40,9 +41,7 @@
 )
 from posthog.temporal.common.clickhouse import ClickHouseClient
 from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run
-from posthog.temporal.tests.utils.events import (
-    generate_test_events_in_clickhouse,
-)
+from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 from posthog.temporal.tests.utils.models import (
     acreate_batch_export,
     adelete_batch_export,
@@ -1576,6 +1575,23 @@ async def client(self, *args, **kwargs):
         await s3_upload.upload_part(io.BytesIO(b"1010"), rewind=False)  # type: ignore
 
 
+async def test_s3_multi_part_upload_raises_exception_if_invalid_endpoint(bucket_name, s3_key_prefix):
+    """Test an InvalidS3EndpointError is raised if the endpoint is invalid."""
+    s3_upload = S3MultiPartUpload(
+        bucket_name=bucket_name,
+        key=s3_key_prefix,
+        encryption=None,
+        kms_key_id=None,
+        region_name="us-east-1",
+        aws_access_key_id="object_storage_root_user",
+        aws_secret_access_key="object_storage_root_password",
+        endpoint_url="some-invalid-endpoint",
+    )
+
+    with pytest.raises(InvalidS3EndpointError):
+        await s3_upload.start()
+
+
 @pytest.mark.parametrize("model", [TEST_S3_MODELS[1], TEST_S3_MODELS[2], None])
 async def test_s3_export_workflow_with_request_timeouts(
     clickhouse_client,
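Note: the `except ValueError` branch works because botocore validates `endpoint_url` at client-creation time and raises a bare `ValueError` whose message contains "Invalid endpoint" (the string the new branch matches on), rather than a botocore-specific exception type. A minimal sketch of that underlying behavior, using synchronous boto3 instead of the aioboto3 session wrapped in the diff; the region and credential values are placeholders:

```python
import boto3

# An endpoint URL without a scheme (e.g. missing "https://") is rejected
# when the client is constructed, before any request is made. The except
# branch in this diff converts that ValueError into InvalidS3EndpointError,
# which is listed as non-retryable so the workflow fails fast instead of
# retrying a misconfigured destination.
try:
    boto3.client(
        "s3",
        region_name="us-east-1",
        aws_access_key_id="placeholder",       # placeholder credentials,
        aws_secret_access_key="placeholder",   # never used for a request
        endpoint_url="some-invalid-endpoint",  # no scheme -> invalid
    )
except ValueError as err:
    assert "Invalid endpoint" in str(err)
```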