refactor: Refactor snowflake to use spmc abstractions
tomasfarias committed Dec 13, 2024
1 parent 646de76 commit b9795bd
Showing 1 changed file with 136 additions and 133 deletions.
269 changes: 136 additions & 133 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -12,7 +12,7 @@
import snowflake.connector
from django.conf import settings
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import OperationalError, InterfaceError
from snowflake.connector.errors import InterfaceError, OperationalError
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

@@ -31,32 +31,43 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
iter_model_records,
start_batch_export_run,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
get_rows_exported_metric,
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
should_resume_from_activity_heartbeat,
)
from posthog.temporal.batch_exports.spmc import (
Consumer,
Producer,
RecordBatchQueue,
run_consumer_loop,
wait_for_schema_or_producer,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
JSONLBatchExportWriter,
WriterFormat,
)
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
cast_record_batch_json_columns,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
HeartbeatParseError,
should_resume_from_activity_heartbeat,
)

NON_RETRYABLE_ERROR_TYPES = [
# Raised when we cannot connect to Snowflake.
"DatabaseError",
# Raised by Snowflake when a query cannot be compiled.
# Usually this means we don't have table permissions or something doesn't exist (db, schema).
"ProgrammingError",
# Raised by Snowflake with an incorrect account name.
"ForbiddenError",
# Our own exception when we can't connect to Snowflake, usually due to invalid parameters.
"SnowflakeConnectionError",
]


class SnowflakeFileNotUploadedError(Exception):
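For context, the spmc module imported above provides the shared single-producer, multi-consumer plumbing this commit adopts: a producer streams record batches from ClickHouse into a queue while a consumer drains them and flushes files to the destination. The snippet below is only a rough sketch of that shape using plain asyncio and hypothetical names; it is not the posthog.temporal.batch_exports.spmc API.

import asyncio

import pyarrow as pa


async def sketch_producer(queue: asyncio.Queue) -> None:
    # Hypothetical stand-in: the real Producer streams record batches out of ClickHouse.
    await queue.put(pa.RecordBatch.from_pydict({"event": ["$pageview", "$autocapture"]}))


async def sketch_consumer(queue: asyncio.Queue, producer_task: asyncio.Task) -> int:
    # Hypothetical stand-in: the real Consumer writes batches to a temporary file and
    # flushes it to the destination once the file grows past a size threshold.
    rows = 0
    while not (producer_task.done() and queue.empty()):
        try:
            batch = await asyncio.wait_for(queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            continue
        rows += batch.num_rows
    return rows


async def sketch_run() -> int:
    queue: asyncio.Queue = asyncio.Queue()
    producer_task = asyncio.create_task(sketch_producer(queue))
    return await sketch_consumer(queue, producer_task)  # asyncio.run(sketch_run()) == 2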
@@ -91,37 +102,9 @@ class SnowflakeRetryableConnectionError(Exception):

@dataclasses.dataclass
class SnowflakeHeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""The Snowflake batch export details included in every heartbeat.
Attributes:
file_no: The file number of the last file we managed to upload.
"""

file_no: int = 0

@classmethod
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
file_no = 0
remaining = super().deserialize_details(details)

if len(remaining["_remaining"]) == 0:
return {"file_no": 0, **remaining}

first_detail = remaining["_remaining"][0]
remaining["_remaining"] = remaining["_remaining"][1:]

try:
file_no = int(first_detail)
except (TypeError, ValueError) as e:
raise HeartbeatParseError("file_no") from e

return {"file_no": file_no, **remaining}
"""The Snowflake batch export details included in every heartbeat."""

def serialize_details(self) -> tuple[typing.Any, ...]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
serialized_parent_details = super().serialize_details()
return (*serialized_parent_details[:-1], self.file_no, self._remaining)
pass
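With the base BatchExportRangeHeartbeatDetails already recording which date ranges have been fully exported, the per-file counter and its custom (de)serialization above are no longer needed: resuming only requires knowing where the last completed range ended, which is what the activity does further down. A minimal sketch of that idea (hypothetical helper, not the real heartbeat API):

import datetime as dt


def resume_point(
    done_ranges: list[tuple[dt.datetime, dt.datetime]],
    original_start: dt.datetime | None,
) -> dt.datetime | None:
    """Return where a resumed export should pick up from (sketch only)."""
    if done_ranges:
        # Everything up to the end of the last completed range was already exported.
        return done_ranges[-1][1]
    return original_start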


@dataclasses.dataclass
@@ -344,22 +327,16 @@ async def put_file_to_snowflake_table(
file: BatchExportTemporaryFile,
table_stage_prefix: str,
table_name: str,
file_no: int,
):
"""Executes a PUT query using the provided cursor to the provided table_name.
Sadly, Snowflake's execute_async does not work with PUT statements. So, we pass the execute
call to run_in_executor: Since execute ends up boiling down to blocking IO (HTTP request),
the event loop should not be locked up.
We add a file_no to the file_name when executing PUT as Snowflake will reject any files with the same
name. Since batch exports re-use the same file, our name does not change, but we don't want Snowflake
to reject or overwrite our new data.
Args:
file: The name of the local file to PUT.
table_name: The name of the Snowflake table where to PUT the file.
file_no: An int to identify which file number this is.
Raises:
TypeError: If we don't get a tuple back from Snowflake (should never happen).
@@ -371,7 +348,7 @@ async def put_file_to_snowflake_table(
# So we ask mypy to be nice with us.
reader = io.BufferedReader(file) # type: ignore
query = f"""
PUT file://{file.name}_{file_no}.jsonl '@%"{table_name}"/{table_stage_prefix}'
PUT file://{file.name} '@%"{table_name}"/{table_stage_prefix}'
"""

with self.connection.cursor() as cursor:
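As the docstring explains, the Snowflake connector's execute call is blocking, so the PUT is handed to an executor rather than awaited directly. A minimal sketch of that technique, assuming the connector's file_stream parameter for streaming uploads (the names here are illustrative, not the activity's actual helper):

import asyncio
import functools


async def run_blocking_put(cursor, query: str, reader) -> None:
    """Offload a blocking Snowflake PUT to the default thread executor (sketch only)."""
    loop = asyncio.get_running_loop()
    # cursor.execute boils down to blocking HTTP I/O; running it in a worker thread
    # keeps the event loop free to answer Temporal heartbeats while the file uploads.
    await loop.run_in_executor(None, functools.partial(cursor.execute, query, file_stream=reader))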
@@ -518,6 +495,52 @@ def snowflake_default_fields() -> list[BatchExportField]:
return batch_export_fields


class SnowflakeConsumer(Consumer):
def __init__(
self,
heartbeater: Heartbeater,
heartbeat_details: SnowflakeHeartbeatDetails,
data_interval_start: dt.datetime | str | None,
snowflake_client: SnowflakeClient,
snowflake_table: str,
snowflake_table_stage_prefix: str,
):
super().__init__(heartbeater, heartbeat_details, data_interval_start)
self.heartbeat_details: SnowflakeHeartbeatDetails = heartbeat_details
self.snowflake_table = snowflake_table
self.snowflake_client = snowflake_client
self.snowflake_table_stage_prefix = snowflake_table_stage_prefix

async def flush(
self,
batch_export_file: BatchExportTemporaryFile,
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_date_range: DateRange,
is_last: bool,
error: Exception | None,
):
await self.logger.ainfo(
"Putting file %s containing %s records with size %s bytes",
flush_counter,
records_since_last_flush,
bytes_since_last_flush,
)

await self.snowflake_client.put_file_to_snowflake_table(
batch_export_file,
self.snowflake_table_stage_prefix,
self.snowflake_table,
)

await self.logger.adebug("Loaded %s to Snowflake table '%s'", records_since_last_flush, self.snowflake_table)
self.rows_exported_counter.add(records_since_last_flush)
self.bytes_exported_counter.add(bytes_since_last_flush)

self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start)


def get_snowflake_fields_from_record_schema(
record_schema: pa.Schema, known_variant_columns: list[str]
) -> list[SnowflakeField]:
@@ -594,42 +617,63 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
details = SnowflakeHeartbeatDetails()

done_ranges: list[DateRange] = details.done_ranges
if done_ranges:
data_interval_start: str | None = done_ranges[-1][1].isoformat()
else:
data_interval_start = inputs.data_interval_start

current_flush_counter = details.file_no

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model
if model is not None:
model_name = model.name
extra_query_parameters = model.schema["values"] if model.schema is not None else None
fields = model.schema["fields"] if model.schema is not None else None
else:
model_name = "events"
extra_query_parameters = None
fields = None
else:
model = inputs.batch_export_schema
model_name = "custom"
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None

records_iterator = iter_model_records(
client=client,
model=model,
data_interval_start = (
dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None
)
data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end)
full_range = (data_interval_start, data_interval_end)

queue = RecordBatchQueue()
producer = Producer(clickhouse_client=client)
producer_task = producer.start(
queue=queue,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
full_range=full_range,
done_ranges=done_ranges,
fields=fields,
destination_default_fields=snowflake_default_fields(),
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
destination_default_fields=snowflake_default_fields(),
is_backfill=inputs.is_backfill,
extra_query_parameters=extra_query_parameters,
)
records_completed = 0

record_batch_schema = await wait_for_schema_or_producer(queue, producer_task)
if record_batch_schema is None:
return records_completed
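wait_for_schema_or_producer blocks until either a first record batch arrives, yielding the schema used to build the Snowflake table, or the producer finishes without producing anything, in which case the activity returns early. A rough sketch of that wait, assuming a plain asyncio.Queue; producer error propagation and queue ordering are sketch-level simplifications:

import asyncio


async def schema_or_none(queue: asyncio.Queue, producer_task: asyncio.Task):
    """Return the first batch's schema, or None if the producer yields nothing (sketch only)."""
    get_first = asyncio.create_task(queue.get())
    await asyncio.wait({get_first, producer_task}, return_when=asyncio.FIRST_COMPLETED)
    if not get_first.done() and producer_task.done() and queue.empty():
        # The producer finished without ever enqueueing a batch.
        get_first.cancel()
        return None
    batch = await get_first
    queue.put_nowait(batch)  # put the peeked batch back so the consumer still sees it
    return batch.schema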

record_batch_schema = pa.schema(
# NOTE: For some reason, some record batches set fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why
# schemas differ between batches.
[field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"]
)
first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator)

if first_record_batch is None:
return 0

known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=known_variant_columns)

if model is None or (isinstance(model, BatchExportModel) and model.name == "events"):
table_fields = [
@@ -647,10 +691,8 @@
]

else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_snowflake_fields_from_record_schema(
record_schema,
record_batch_schema,
known_variant_columns=known_variant_columns,
)

@@ -671,57 +713,28 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
stagle_table_name, data_interval_end_str, table_fields, create=requires_merge, delete=requires_merge
) as snow_stage_table,
):
record_columns = [field[0] for field in table_fields]
record_schema = pa.schema(
[field.with_nullable(True) for field in first_record_batch.select(record_columns).schema]
)

async def flush_to_snowflake(
local_results_file,
records_since_last_flush,
bytes_since_last_flush,
flush_counter: int,
last_date_range: DateRange,
last: bool,
error: Exception | None,
):
logger.info(
"Putting %sfile %s containing %s records with size %s bytes",
"last " if last else "",
flush_counter,
records_since_last_flush,
bytes_since_last_flush,
)

table = snow_stage_table if requires_merge else snow_table

await snow_client.put_file_to_snowflake_table(
local_results_file, data_interval_end_str, table, flush_counter
)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

details.track_done_range(last_date_range, data_interval_start)
details.file_no = flush_counter
heartbeater.set_from_heartbeat_details(details)

writer = JSONLBatchExportWriter(
max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES,
flush_callable=flush_to_snowflake,
records_completed = await run_consumer_loop(
queue=queue,
consumer_cls=SnowflakeConsumer,
producer_task=producer_task,
heartbeater=heartbeater,
heartbeat_details=details,
data_interval_end=data_interval_end,
data_interval_start=data_interval_start,
schema=record_batch_schema,
writer_format=WriterFormat.JSONL,
max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES,
non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES,
json_columns=known_variant_columns,
snowflake_client=snow_client,
snowflake_table=snow_stage_table if requires_merge else snow_table,
snowflake_table_stage_prefix=data_interval_end_str,
)

async with writer.open_temporary_file(current_flush_counter):
async for record_batch in records_iterator:
record_batch = cast_record_batch_json_columns(record_batch, json_columns=known_variant_columns)

await writer.write_record_batch(record_batch)

details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)

await snow_client.copy_loaded_files_to_snowflake_table(
snow_stage_table if requires_merge else snow_table, data_interval_end_str
)

if requires_merge:
merge_key = (
("team_id", "INT64"),
@@ -734,7 +747,7 @@ async def flush_to_snowflake(
merge_key=merge_key,
)

return writer.records_total
return records_completed


@workflow.defn(name="snowflake-export", failure_exception_types=[workflow.NondeterminismError])
@@ -811,16 +824,6 @@ async def run(self, inputs: SnowflakeBatchExportInputs):
insert_into_snowflake_activity,
insert_inputs,
interval=inputs.interval,
non_retryable_error_types=[
# Raised when we cannot connect to Snowflake.
"DatabaseError",
# Raised by Snowflake when a query cannot be compiled.
# Usually this means we don't have table permissions or something doesn't exist (db, schema).
"ProgrammingError",
# Raised by Snowflake with an incorrect account name.
"ForbiddenError",
# Our own exception when we can't connect to Snowflake, usually due to invalid parameters.
"SnowflakeConnectionError",
],
non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES,
finish_inputs=finish_inputs,
)
