From 7f06becdbc38289d9ddda6fd60fa812285ed3e2f Mon Sep 17 00:00:00 2001 From: Ross Date: Fri, 13 Dec 2024 09:47:00 +0000 Subject: [PATCH] chore(batch-exports): Add batch export monitoring workflow (#26728) --- posthog/batch_exports/service.py | 16 ++ posthog/batch_exports/sql.py | 19 ++ posthog/temporal/batch_exports/__init__.py | 10 + posthog/temporal/batch_exports/monitoring.py | 227 ++++++++++++++++++ .../temporal/tests/batch_exports/conftest.py | 8 +- .../tests/batch_exports/test_monitoring.py | 201 ++++++++++++++++ 6 files changed, 479 insertions(+), 2 deletions(-) create mode 100644 posthog/temporal/batch_exports/monitoring.py create mode 100644 posthog/temporal/tests/batch_exports/test_monitoring.py diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index 59b217b4fc8f2..d17bb3b1b69c3 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -794,3 +794,19 @@ async def aupdate_batch_export_backfill_status(backfill_id: UUID, status: str) - raise ValueError(f"BatchExportBackfill with id {backfill_id} not found.") return await model.aget() + + +async def aupdate_records_total_count( + batch_export_id: UUID, interval_start: dt.datetime, interval_end: dt.datetime, count: int +) -> int: + """Update the expected records count for a set of batch export runs. + + Typically, there is one batch export run per batch export interval, however + there could be multiple if data has been backfilled. + """ + rows_updated = await BatchExportRun.objects.filter( + batch_export_id=batch_export_id, + data_interval_start=interval_start, + data_interval_end=interval_end, + ).aupdate(records_total_count=count) + return rows_updated diff --git a/posthog/batch_exports/sql.py b/posthog/batch_exports/sql.py index baa0216afdbbc..9a7fd0cea95aa 100644 --- a/posthog/batch_exports/sql.py +++ b/posthog/batch_exports/sql.py @@ -318,3 +318,22 @@ SETTINGS optimize_aggregation_in_order=1 ) """ + +# TODO: is this the best query to use? 
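+# Illustrative note on the query below (the example values are made up): it
+# buckets events by _inserted_at into one row per export interval and returns
+# three tab-separated columns -- interval_start, interval_end and total_count.
+# For a 5 minute interval a row might look like:
+#
+#   2024-12-13 14:00:00    2024-12-13 14:05:00    1234
+#
+# The {team_id}, {interval}, {overall_interval_start}, {overall_interval_end},
+# {include_events} and {exclude_events} placeholders are bound as query
+# parameters by the monitoring activity that runs this query.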
+EVENT_COUNT_BY_INTERVAL = """ +SELECT + toStartOfInterval(_inserted_at, INTERVAL {interval}) AS interval_start, + interval_start + INTERVAL {interval} AS interval_end, + COUNT(*) as total_count +FROM + events_batch_export_recent( + team_id={team_id}, + interval_start={overall_interval_start}, + interval_end={overall_interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) AS events +GROUP BY interval_start +ORDER BY interval_start desc +SETTINGS max_replica_delay_for_distributed_queries=1 +""" diff --git a/posthog/temporal/batch_exports/__init__.py b/posthog/temporal/batch_exports/__init__.py index 33c1b200e6a97..a3616f1107c5b 100644 --- a/posthog/temporal/batch_exports/__init__.py +++ b/posthog/temporal/batch_exports/__init__.py @@ -17,6 +17,12 @@ HttpBatchExportWorkflow, insert_into_http_activity, ) +from posthog.temporal.batch_exports.monitoring import ( + BatchExportMonitoringWorkflow, + get_batch_export, + get_event_counts, + update_batch_export_runs, +) from posthog.temporal.batch_exports.noop import NoOpWorkflow, noop_activity from posthog.temporal.batch_exports.postgres_batch_export import ( PostgresBatchExportWorkflow, @@ -54,6 +60,7 @@ SnowflakeBatchExportWorkflow, HttpBatchExportWorkflow, SquashPersonOverridesWorkflow, + BatchExportMonitoringWorkflow, ] ACTIVITIES = [ @@ -76,4 +83,7 @@ update_batch_export_backfill_model_status, wait_for_mutation, wait_for_table, + get_batch_export, + get_event_counts, + update_batch_export_runs, ] diff --git a/posthog/temporal/batch_exports/monitoring.py b/posthog/temporal/batch_exports/monitoring.py new file mode 100644 index 0000000000000..97eaf6c2430d9 --- /dev/null +++ b/posthog/temporal/batch_exports/monitoring.py @@ -0,0 +1,227 @@ +import datetime as dt +import json +from dataclasses import dataclass +from uuid import UUID + +from temporalio import activity, workflow +from temporalio.common import RetryPolicy + +from posthog.batch_exports.models import BatchExport +from posthog.batch_exports.service import aupdate_records_total_count +from posthog.batch_exports.sql import EVENT_COUNT_BY_INTERVAL +from posthog.temporal.batch_exports.base import PostHogWorkflow +from posthog.temporal.common.clickhouse import get_client +from posthog.temporal.common.heartbeat import Heartbeater + + +class BatchExportNotFoundError(Exception): + """Exception raised when batch export is not found.""" + + def __init__(self, batch_export_id: UUID): + super().__init__(f"Batch export with id {batch_export_id} not found") + + +class NoValidBatchExportsFoundError(Exception): + """Exception raised when no valid batch export is found.""" + + def __init__(self, message: str = "No valid batch exports found"): + super().__init__(message) + + +@dataclass +class BatchExportMonitoringInputs: + """Inputs for the BatchExportMonitoringWorkflow. + + Attributes: + batch_export_id: The batch export id to monitor. 
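+
+    When triggered from the management command CLI, these inputs are parsed
+    from a JSON payload; an illustrative example (the UUID is made up):
+
+        {"batch_export_id": "e0198b2e-8f6a-4a4f-97a3-0b3a1b3b9f00"}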
+ """ + + batch_export_id: UUID + + +@dataclass +class BatchExportDetails: + id: UUID + team_id: int + interval: str + exclude_events: list[str] + include_events: list[str] + + +@activity.defn +async def get_batch_export(batch_export_id: UUID) -> BatchExportDetails: + """Fetch a batch export from the database and return its details.""" + batch_export = ( + await BatchExport.objects.filter(id=batch_export_id, model="events", paused=False, deleted=False) + .prefetch_related("destination") + .afirst() + ) + if batch_export is None: + raise BatchExportNotFoundError(batch_export_id) + if batch_export.deleted is True: + raise NoValidBatchExportsFoundError("Batch export has been deleted") + if batch_export.paused is True: + raise NoValidBatchExportsFoundError("Batch export is paused") + if batch_export.model != "events": + raise NoValidBatchExportsFoundError("Batch export model is not 'events'") + if batch_export.interval_time_delta != dt.timedelta(minutes=5): + raise NoValidBatchExportsFoundError( + "Only batch exports with interval of 5 minutes are supported for monitoring at this time." + ) + config = batch_export.destination.config + return BatchExportDetails( + id=batch_export.id, + team_id=batch_export.team_id, + interval=batch_export.interval, + exclude_events=config.get("exclude_events", []), + include_events=config.get("include_events", []), + ) + + +@dataclass +class GetEventCountsInputs: + team_id: int + interval: str + overall_interval_start: str + overall_interval_end: str + exclude_events: list[str] + include_events: list[str] + + +@dataclass +class EventCountsOutput: + interval_start: str + interval_end: str + count: int + + +@dataclass +class GetEventCountsOutputs: + results: list[EventCountsOutput] + + +@activity.defn +async def get_event_counts(inputs: GetEventCountsInputs) -> GetEventCountsOutputs: + """Get the total number of events for a given team over a set of time intervals.""" + + query = EVENT_COUNT_BY_INTERVAL + + interval = inputs.interval + # we check interval is "every 5 minutes" above but double check here + if not interval.startswith("every 5 minutes"): + raise NoValidBatchExportsFoundError( + "Only intervals of 'every 5 minutes' are supported for monitoring at this time." 
+ ) + _, value, unit = interval.split(" ") + interval = f"{value} {unit}" + + query_params = { + "team_id": inputs.team_id, + "interval": interval, + "overall_interval_start": inputs.overall_interval_start, + "overall_interval_end": inputs.overall_interval_end, + "include_events": inputs.include_events, + "exclude_events": inputs.exclude_events, + } + async with Heartbeater(), get_client() as client: + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") + + response = await client.read_query(query, query_params) + results = [] + for line in response.decode("utf-8").splitlines(): + interval_start, interval_end, count = line.strip().split("\t") + results.append( + EventCountsOutput(interval_start=interval_start, interval_end=interval_end, count=int(count)) + ) + + return GetEventCountsOutputs(results=results) + + +@dataclass +class UpdateBatchExportRunsInputs: + batch_export_id: UUID + results: list[EventCountsOutput] + + +@activity.defn +async def update_batch_export_runs(inputs: UpdateBatchExportRunsInputs) -> int: + """Update BatchExportRuns with the expected number of events.""" + + total_rows_updated = 0 + async with Heartbeater(): + for result in inputs.results: + total_rows_updated += await aupdate_records_total_count( + batch_export_id=inputs.batch_export_id, + interval_start=dt.datetime.strptime(result.interval_start, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC), + interval_end=dt.datetime.strptime(result.interval_end, "%Y-%m-%d %H:%M:%S").replace(tzinfo=dt.UTC), + count=result.count, + ) + activity.logger.info(f"Updated {total_rows_updated} BatchExportRuns") + return total_rows_updated + + +@workflow.defn(name="batch-export-monitoring") +class BatchExportMonitoringWorkflow(PostHogWorkflow): + """Workflow to monitor batch exports. + + We have had some issues with batch exports in the past, where some events + have been missing. The purpose of this workflow is to monitor the status of + batch exports for a given customer by reconciling the number of exported + events with the number of events in ClickHouse for a given interval. 
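+
+    As a rough illustration of the window that gets reconciled (times below
+    are made up): if the workflow runs at 15:30 UTC, it checks the 13:00 to
+    14:00 window in 5 minute buckets, i.e. not the hour that has just passed
+    but the one before it:
+
+        now = dt.datetime(2024, 12, 13, 15, 30, tzinfo=dt.UTC)
+        interval_end = now.replace(minute=0, second=0, microsecond=0) - dt.timedelta(hours=1)  # 14:00
+        interval_start = interval_end - dt.timedelta(hours=1)  # 13:00
+
+    The most recent hour is deliberately skipped so that recent batch export
+    runs have had a chance to complete before their counts are compared.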
+ """ + + @staticmethod + def parse_inputs(inputs: list[str]) -> BatchExportMonitoringInputs: + """Parse inputs from the management command CLI.""" + loaded = json.loads(inputs[0]) + return BatchExportMonitoringInputs(**loaded) + + @workflow.run + async def run(self, inputs: BatchExportMonitoringInputs): + """Workflow implementation to monitor batch exports for a given team.""" + # TODO - check if this is the right way to do logging since there seems to be a few different ways + workflow.logger.info( + "Starting batch exports monitoring workflow for batch export id %s", inputs.batch_export_id + ) + + batch_export_details = await workflow.execute_activity( + get_batch_export, + inputs.batch_export_id, + start_to_close_timeout=dt.timedelta(minutes=1), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=20), + non_retryable_error_types=["BatchExportNotFoundError", "NoValidBatchExportsFoundError"], + ), + ) + + # time interval to check is not the previous hour but the hour before that + # (just to ensure all recent batch exports have run successfully) + now = dt.datetime.now(tz=dt.UTC) + interval_end = now.replace(minute=0, second=0, microsecond=0) - dt.timedelta(hours=1) + interval_start = interval_end - dt.timedelta(hours=1) + interval_end_str = interval_end.strftime("%Y-%m-%d %H:%M:%S") + interval_start_str = interval_start.strftime("%Y-%m-%d %H:%M:%S") + + total_events = await workflow.execute_activity( + get_event_counts, + GetEventCountsInputs( + team_id=batch_export_details.team_id, + interval=batch_export_details.interval, + overall_interval_start=interval_start_str, + overall_interval_end=interval_end_str, + exclude_events=batch_export_details.exclude_events, + include_events=batch_export_details.include_events, + ), + start_to_close_timeout=dt.timedelta(hours=1), + retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)), + heartbeat_timeout=dt.timedelta(minutes=1), + ) + + return await workflow.execute_activity( + update_batch_export_runs, + UpdateBatchExportRunsInputs(batch_export_id=batch_export_details.id, results=total_events.results), + start_to_close_timeout=dt.timedelta(hours=1), + retry_policy=RetryPolicy(maximum_attempts=3, initial_interval=dt.timedelta(seconds=20)), + heartbeat_timeout=dt.timedelta(minutes=1), + ) diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py index 67c321205a14f..7044d8fe96868 100644 --- a/posthog/temporal/tests/batch_exports/conftest.py +++ b/posthog/temporal/tests/batch_exports/conftest.py @@ -152,8 +152,8 @@ async def create_clickhouse_tables_and_views(clickhouse_client, django_db_setup) from posthog.batch_exports.sql import ( CREATE_EVENTS_BATCH_EXPORT_VIEW, CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, - CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, CREATE_EVENTS_BATCH_EXPORT_VIEW_RECENT, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, CREATE_PERSONS_BATCH_EXPORT_VIEW, CREATE_PERSONS_BATCH_EXPORT_VIEW_BACKFILL, ) @@ -211,8 +211,12 @@ def data_interval_start(request, data_interval_end, interval): @pytest.fixture -def data_interval_end(interval): +def data_interval_end(request, interval): """Set a test data interval end.""" + try: + return request.param + except AttributeError: + pass return dt.datetime(2023, 4, 25, 15, 0, 0, tzinfo=dt.UTC) diff --git a/posthog/temporal/tests/batch_exports/test_monitoring.py b/posthog/temporal/tests/batch_exports/test_monitoring.py new file mode 100644 index 0000000000000..cab50c25d3177 --- /dev/null +++ 
b/posthog/temporal/tests/batch_exports/test_monitoring.py @@ -0,0 +1,201 @@ +import datetime as dt +import uuid + +import pytest +import pytest_asyncio +from temporalio.common import RetryPolicy +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import UnsandboxedWorkflowRunner, Worker + +from posthog import constants +from posthog.batch_exports.models import BatchExportRun +from posthog.temporal.batch_exports.monitoring import ( + BatchExportMonitoringInputs, + BatchExportMonitoringWorkflow, + get_batch_export, + get_event_counts, + update_batch_export_runs, +) +from posthog.temporal.tests.utils.models import ( + acreate_batch_export, + adelete_batch_export, + afetch_batch_export_runs, +) + +pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] + +GENERATE_TEST_DATA_END = dt.datetime.now(tz=dt.UTC).replace( + minute=0, second=0, microsecond=0, tzinfo=dt.UTC +) - dt.timedelta(hours=1) +GENERATE_TEST_DATA_START = GENERATE_TEST_DATA_END - dt.timedelta(hours=1) + + +@pytest_asyncio.fixture +async def batch_export(ateam, temporal_client): + """Provide a batch export for tests, not intended to be used.""" + destination_data = { + "type": "S3", + "config": { + "bucket_name": "a-bucket", + "region": "us-east-1", + "prefix": "a-key", + "aws_access_key_id": "object_storage_root_user", + "aws_secret_access_key": "object_storage_root_password", + }, + } + + batch_export_data = { + "name": "my-production-s3-bucket-destination", + "destination": destination_data, + "interval": "every 5 minutes", + } + + batch_export = await acreate_batch_export( + team_id=ateam.pk, + name=batch_export_data["name"], # type: ignore + destination_data=batch_export_data["destination"], # type: ignore + interval=batch_export_data["interval"], # type: ignore + ) + + yield batch_export + + await adelete_batch_export(batch_export, temporal_client) + + +@pytest_asyncio.fixture +async def generate_batch_export_runs( + generate_test_data, + data_interval_start: dt.datetime, + data_interval_end: dt.datetime, + interval: str, + batch_export, +): + # to keep things simple for now, we assume 5 min interval + if interval != "every 5 minutes": + raise NotImplementedError("Only 5 minute intervals are supported for now. 
Please update the test.")
+
+    events_created, _ = generate_test_data
+
+    batch_export_runs: list[BatchExportRun] = []
+    interval_start = data_interval_start
+    interval_end = interval_start + dt.timedelta(minutes=5)
+    while interval_end <= data_interval_end:
+        run = BatchExportRun(
+            batch_export_id=batch_export.id,
+            data_interval_start=interval_start,
+            data_interval_end=interval_end,
+            status="completed",
+            records_completed=len(
+                [
+                    e
+                    for e in events_created
+                    if interval_start
+                    <= dt.datetime.fromisoformat(e["inserted_at"]).replace(tzinfo=dt.UTC)
+                    < interval_end
+                ]
+            ),
+        )
+        await run.asave()
+        batch_export_runs.append(run)
+        interval_start = interval_end
+        interval_end += dt.timedelta(minutes=5)
+
+    yield
+
+    for run in batch_export_runs:
+        await run.adelete()
+
+
+async def test_monitoring_workflow_when_no_event_data(batch_export):
+    workflow_id = str(uuid.uuid4())
+    inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id)
+    async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
+        async with Worker(
+            activity_environment.client,
+            # TODO - not sure if this is the right task queue
+            task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+            workflows=[BatchExportMonitoringWorkflow],
+            activities=[
+                get_batch_export,
+                get_event_counts,
+                update_batch_export_runs,
+            ],
+            workflow_runner=UnsandboxedWorkflowRunner(),
+        ):
+            batch_export_runs_updated = await activity_environment.client.execute_workflow(
+                BatchExportMonitoringWorkflow.run,
+                inputs,
+                id=workflow_id,
+                task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
+                retry_policy=RetryPolicy(maximum_attempts=1),
+                execution_timeout=dt.timedelta(seconds=30),
+            )
+            assert batch_export_runs_updated == 0
+
+
+@pytest.mark.parametrize(
+    "data_interval_start",
+    [GENERATE_TEST_DATA_START],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    "data_interval_end",
+    [GENERATE_TEST_DATA_END],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    "interval",
+    ["every 5 minutes"],
+    indirect=True,
+)
+async def test_monitoring_workflow(
+    batch_export,
+    generate_test_data,
+    data_interval_start,
+    data_interval_end,
+    interval,
+    generate_batch_export_runs,
+):
+    """Test the monitoring workflow with a batch export that has data.
+
+    We generate an hour of data in the window the monitoring workflow will
+    check: for example, if the workflow runs at 15:30, data is generated
+    between 13:00 and 14:00 and the workflow reconciles that same 13:00 to
+    14:00 window, updating the batch export runs.
+
+    We generate some dummy batch export runs based on the event data we
+    generated and assert that the expected records count matches the records
+    completed.
+ """ + workflow_id = str(uuid.uuid4()) + inputs = BatchExportMonitoringInputs(batch_export_id=batch_export.id) + async with await WorkflowEnvironment.start_time_skipping() as activity_environment: + async with Worker( + activity_environment.client, + # TODO - not sure if this is the right task queue + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + workflows=[BatchExportMonitoringWorkflow], + activities=[ + get_batch_export, + get_event_counts, + update_batch_export_runs, + ], + workflow_runner=UnsandboxedWorkflowRunner(), + ): + await activity_environment.client.execute_workflow( + BatchExportMonitoringWorkflow.run, + inputs, + id=workflow_id, + task_queue=constants.BATCH_EXPORTS_TASK_QUEUE, + retry_policy=RetryPolicy(maximum_attempts=1), + execution_timeout=dt.timedelta(seconds=30), + ) + + batch_export_runs = await afetch_batch_export_runs(batch_export_id=batch_export.id) + + for run in batch_export_runs: + if run.records_completed == 0: + # TODO: in the actual monitoring activity it would be better to + # update the actual count to 0 rather than None + assert run.records_total_count is None + else: + assert run.records_completed == run.records_total_count
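
Note: outside of the Temporal test environment used above, the same workflow can
be started with a plain Temporal client, mirroring how the tests call
execute_workflow. A minimal sketch, assuming a reachable Temporal server, a
worker consuming the batch exports task queue, and an existing batch export id
(the server address and UUID below are made-up placeholders; a real PostHog
deployment may also need its own client configuration):

    import asyncio
    import datetime as dt
    import uuid

    from temporalio.client import Client

    from posthog import constants
    from posthog.temporal.batch_exports.monitoring import (
        BatchExportMonitoringInputs,
        BatchExportMonitoringWorkflow,
    )

    async def main() -> None:
        # Placeholder address; point this at your Temporal frontend.
        client = await Client.connect("localhost:7233")

        # The workflow returns the number of BatchExportRun rows it updated.
        rows_updated = await client.execute_workflow(
            BatchExportMonitoringWorkflow.run,
            BatchExportMonitoringInputs(
                batch_export_id=uuid.UUID("12345678-1234-1234-1234-123456789012")
            ),
            id=f"batch-export-monitoring-{uuid.uuid4()}",
            task_queue=constants.BATCH_EXPORTS_TASK_QUEUE,
            execution_timeout=dt.timedelta(minutes=10),
        )
        print(f"Updated {rows_updated} batch export runs")

    asyncio.run(main())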