
refactor: Use heartbeat date ranges to track progress (#26094)

Authored by Tomás Farías Santana on 2024-11-26 11:08:53 +01:00; committed by GitHub
parent a3294feb02
commit 594aad3063
19 changed files with 1567 additions and 428 deletions

View File

@ -1,6 +1,3 @@
posthog/temporal/common/utils.py:0: error: Argument 1 to "abstractclassmethod" has incompatible type "Callable[[HeartbeatDetails, Any], Any]"; expected "Callable[[type[Never], Any], Any]" [arg-type]
posthog/temporal/common/utils.py:0: note: This is likely because "from_activity" has named arguments: "cls". Consider marking them positional-only
posthog/temporal/common/utils.py:0: error: Argument 2 to "__get__" of "classmethod" has incompatible type "type[HeartbeatType]"; expected "type[Never]" [arg-type]
posthog/tasks/exports/ordered_csv_renderer.py:0: error: No return value expected [return-value]
posthog/warehouse/models/ssh_tunnel.py:0: error: Incompatible types in assignment (expression has type "NoEncryption", variable has type "BestAvailableEncryption") [assignment]
posthog/temporal/data_imports/pipelines/sql_database_v2/schema_types.py:0: error: Statement is unreachable [unreachable]
@ -829,8 +826,6 @@ posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0:
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_async_calls" (hint: "_execute_async_calls: list[<type>] = ...") [var-annotated]
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_cursors" (hint: "_cursors: list[<type>] = ...") [var-annotated]
posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: List item 0 has incompatible type "tuple[str, str, int, int, int, int, str, int]"; expected "tuple[str, str, int, int, str, str, str, str]" [list-item]
posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "last_uploaded_part_timestamp" [attr-defined]
posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "upload_state" [attr-defined]
posthog/migrations/0237_remove_timezone_from_teams.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type]
posthog/migrations/0228_fix_tile_layouts.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type]
posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined]

View File

@ -3,6 +3,7 @@ import collections
import collections.abc
import dataclasses
import datetime as dt
import operator
import typing
import uuid
from string import Template
@ -361,8 +362,8 @@ def start_produce_batch_export_record_batches(
model_name: str,
is_backfill: bool,
team_id: int,
interval_start: str | None,
interval_end: str,
full_range: tuple[dt.datetime | None, dt.datetime],
done_ranges: list[tuple[dt.datetime, dt.datetime]],
fields: list[BatchExportField] | None = None,
destination_default_fields: list[BatchExportField] | None = None,
**parameters,
@ -386,7 +387,7 @@ def start_produce_batch_export_record_batches(
fields = destination_default_fields
if model_name == "persons":
if is_backfill and interval_start is None:
if is_backfill and full_range[0] is None:
view = SELECT_FROM_PERSONS_VIEW_BACKFILL
else:
view = SELECT_FROM_PERSONS_VIEW
@ -420,26 +421,112 @@ def start_produce_batch_export_record_batches(
view = query_template.substitute(fields=query_fields)
if interval_start is not None:
parameters["interval_start"] = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
parameters["interval_end"] = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
parameters["team_id"] = team_id
extra_query_parameters = parameters.pop("extra_query_parameters", {}) or {}
parameters = {**parameters, **extra_query_parameters}
queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES)
query_id = uuid.uuid4()
produce_task = asyncio.create_task(
client.aproduce_query_as_arrow_record_batches(
view, queue=queue, query_parameters=parameters, query_id=str(query_id)
produce_batch_export_record_batches_from_range(
client=client,
query=view,
full_range=full_range,
done_ranges=done_ranges,
queue=queue,
query_parameters=parameters,
)
)
return queue, produce_task
async def produce_batch_export_record_batches_from_range(
client: ClickHouseClient,
query: str,
full_range: tuple[dt.datetime | None, dt.datetime],
done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]],
queue: RecordBatchQueue,
query_parameters: dict[str, typing.Any],
):
"""Produce all record batches into `queue` required to complete `full_range`.
This function will skip over any already completed `done_ranges`.
"""
for interval_start, interval_end in generate_query_ranges(full_range, done_ranges):
if interval_start is not None:
query_parameters["interval_start"] = interval_start.strftime("%Y-%m-%d %H:%M:%S.%f")
query_parameters["interval_end"] = interval_end.strftime("%Y-%m-%d %H:%M:%S.%f")
query_id = uuid.uuid4()
await client.aproduce_query_as_arrow_record_batches(
query, queue=queue, query_parameters=query_parameters, query_id=str(query_id)
)
def generate_query_ranges(
remaining_range: tuple[dt.datetime | None, dt.datetime],
done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]],
) -> typing.Iterator[tuple[dt.datetime | None, dt.datetime]]:
"""Recursively yield ranges of dates that need to be queried.
There are essentially 3 scenarios we are expecting:
1. The batch export just started, so we expect `done_ranges` to be an empty
list, and thus should return the `remaining_range`.
2. The batch export crashed mid-execution, so we have some `done_ranges` that
do not completely add up to the full range. In this case we need to yield
ranges in between all the done ones.
3. The batch export crashed right after we finish, so we have a full list of
`done_ranges` adding up to the `remaining_range`. In this case we should not
yield anything.
Case 1 is fairly trivial and we can simply return `remaining_range` if we get
an empty `done_ranges`.
Case 2 is more complicated and we can expect that the ranges produced by this
function will lead to duplicate events selected, as our batch export query is
inclusive in the lower bound. Since multiple rows may have the same
`inserted_at`, we cannot simply skip an `inserted_at` value, as a row that
hasn't been exported yet may have the same `inserted_at` as a row that
has been exported. So this function will return ranges with `inserted_at`
values that were already exported for at least one event. Ideally, this is
*only* one event, but we can never be certain.
"""
if len(done_ranges) == 0:
yield remaining_range
return
epoch = dt.datetime.fromtimestamp(0, tz=dt.UTC)
list_done_ranges: list[tuple[dt.datetime, dt.datetime]] = list(done_ranges)
list_done_ranges.sort(key=operator.itemgetter(0))
while True:
try:
next_range: tuple[dt.datetime | None, dt.datetime] = list_done_ranges.pop(0)
except IndexError:
if remaining_range[0] != remaining_range[1]:
# If they were equal it would mean we have finished.
yield remaining_range
return
else:
candidate_end_at = next_range[0] if next_range[0] is not None else epoch
candidate_start_at = remaining_range[0]
remaining_range = (next_range[1], remaining_range[1])
if candidate_start_at is not None and candidate_start_at >= candidate_end_at:
# We have landed within a done range.
continue
if candidate_start_at is None and candidate_end_at == epoch:
# We have landed within the first done range of a backfill.
continue
yield (candidate_start_at, candidate_end_at)
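# Illustrative sketch (not part of this commit): `generate_query_ranges` applied to
# case 2 from the docstring above, with one done range in the middle of the full
# range. The `_example_`-prefixed function and its values are hypothetical.
def _example_generate_query_ranges_case_2() -> None:
    full_range = (
        dt.datetime(2023, 7, 31, 12, 0, tzinfo=dt.UTC),
        dt.datetime(2023, 7, 31, 13, 0, tzinfo=dt.UTC),
    )
    done_ranges = [
        (
            dt.datetime(2023, 7, 31, 12, 30, tzinfo=dt.UTC),
            dt.datetime(2023, 7, 31, 12, 45, tzinfo=dt.UTC),
        ),
    ]
    # The single done range splits the full range into the two gaps around it,
    # which are the only ranges that still need to be queried.
    assert list(generate_query_ranges(full_range, done_ranges)) == [
        (full_range[0], done_ranges[0][0]),  # 12:00 -> 12:30
        (done_ranges[0][1], full_range[1]),  # 12:45 -> 13:00
    ]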
async def raise_on_produce_task_failure(produce_task: asyncio.Task) -> None:
"""Raise `RecordBatchProducerError` if a produce task failed.

View File

@ -49,13 +49,15 @@ from posthog.temporal.batch_exports.utils import (
cast_record_batch_json_columns,
set_status_to_running_task,
)
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
should_resume_from_activity_heartbeat,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
should_resume_from_activity_heartbeat,
)
logger = structlog.get_logger()
@ -113,7 +115,7 @@ def get_bigquery_fields_from_record_schema(
@dataclasses.dataclass
class BigQueryHeartbeatDetails(BatchExportHeartbeatDetails):
class BigQueryHeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""The BigQuery batch export details included in every heartbeat."""
pass
@ -366,12 +368,11 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")
should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger)
_, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails)
if details is None:
details = BigQueryHeartbeatDetails()
if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
else:
data_interval_start = inputs.data_interval_start
done_ranges: list[DateRange] = details.done_ranges
model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
@ -392,13 +393,18 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None
data_interval_start = (
dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None
)
data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end)
full_range = (data_interval_start, data_interval_end)
queue, produce_task = start_produce_batch_export_record_batches(
client=client,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
full_range=full_range,
done_ranges=done_ranges,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
@ -490,7 +496,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_inserted_at,
last_date_range,
last: bool,
error: Exception | None,
):
@ -508,7 +514,8 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)
heartbeater.details = (str(last_inserted_at),)
details.track_done_range(last_date_range, data_interval_start)
heartbeater.set_from_heartbeat_details(details)
flush_tasks = []
while not queue.empty() or not produce_task.done():
@ -535,6 +542,9 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
await raise_on_produce_task_failure(produce_task)
await logger.adebug("Successfully consumed all record batches")
details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)
records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks))
if requires_merge:

View File

@ -0,0 +1,215 @@
import typing
import datetime as dt
import collections.abc
import dataclasses
import structlog
from posthog.temporal.common.heartbeat import (
HeartbeatDetails,
HeartbeatParseError,
EmptyHeartbeatError,
NotEnoughHeartbeatValuesError,
)
DateRange = tuple[dt.datetime, dt.datetime]
logger = structlog.get_logger()
@dataclasses.dataclass
class BatchExportRangeHeartbeatDetails(HeartbeatDetails):
"""Details included in every batch export heartbeat.
Attributes:
done_ranges: Date ranges that have been successfully exported.
_remaining: Anything else in the activity details.
"""
done_ranges: list[DateRange] = dataclasses.field(default_factory=list)
_remaining: collections.abc.Sequence[typing.Any] = dataclasses.field(default_factory=tuple)
@classmethod
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Deserialize this from Temporal activity details.
We expect done ranges to be available in the first index of remaining
values. Moreover, we expect datetime values to be ISO-formatted strings.
"""
done_ranges: list[DateRange] = []
remaining = super().deserialize_details(details)
if len(remaining["_remaining"]) == 0:
return {"done_ranges": done_ranges, **remaining}
first_detail = remaining["_remaining"][0]
remaining["_remaining"] = remaining["_remaining"][1:]
for date_str_tuple in first_detail:
try:
range_start, range_end = date_str_tuple
datetime_bounds = (
dt.datetime.fromisoformat(range_start),
dt.datetime.fromisoformat(range_end),
)
except (TypeError, ValueError) as e:
raise HeartbeatParseError("done_ranges") from e
done_ranges.append(datetime_bounds)
return {"done_ranges": done_ranges, **remaining}
def serialize_details(self) -> tuple[typing.Any, ...]:
"""Serialize this into a tuple.
Each datetime from `self.done_ranges` must be cast to string as values must
be JSON-serializable.
"""
serialized_done_ranges = [
(start.isoformat() if start is not None else start, end.isoformat()) for (start, end) in self.done_ranges
]
serialized_parent_details = super().serialize_details()
return (*serialized_parent_details[:-1], serialized_done_ranges, self._remaining)
@property
def empty(self) -> bool:
return len(self.done_ranges) == 0
def track_done_range(
self, done_range: DateRange, data_interval_start_input: str | dt.datetime | None, merge: bool = True
):
"""Track a range of datetime values that has been exported successfully.
If this is the first `done_range` then we override the beginning of the
range to ensure it covers the range from `data_interval_start_input`.
Arguments:
done_range: A date range of values that have been exported.
data_interval_start_input: The `data_interval_start` input passed to
the batch export
merge: Whether to merge the new range with existing ones.
"""
if self.empty is True:
if data_interval_start_input is None:
data_interval_start = dt.datetime.fromtimestamp(0, tz=dt.UTC)
elif isinstance(data_interval_start_input, str):
data_interval_start = dt.datetime.fromisoformat(data_interval_start_input)
else:
data_interval_start = data_interval_start_input
done_range = (data_interval_start, done_range[1])
self.insert_done_range(done_range, merge=merge)
def insert_done_range(self, done_range: DateRange, merge: bool = True):
"""Insert a date range into `self.done_ranges` in order."""
for index, range in enumerate(self.done_ranges, start=0):
if done_range[0] > range[1]:
continue
# We have found the index where this date range should go in.
if done_range[0] == range[1]:
self.done_ranges.insert(index + 1, done_range)
else:
self.done_ranges.insert(index, done_range)
break
else:
# Date range should go at the end
self.done_ranges.append(done_range)
if merge:
self.merge_done_ranges()
def merge_done_ranges(self):
"""Merge as many date ranges together as possible in `self.done_ranges`.
This method looks for ranges whose opposite ends are touching and merges
them together. Notice that this method does not have enough information
to merge ranges that are not touching.
"""
marked_for_deletion = set()
for index, range in enumerate(self.done_ranges, start=0):
if index in marked_for_deletion:
continue
try:
next_range = self.done_ranges[index + 1]
except IndexError:
continue
if next_range[0] == range[1]:
# Touching start of next range with end of range.
# End of next range set as end of existing range.
# Next range marked for deletion as it's now covered by range.
self.done_ranges[index] = (range[0], next_range[1])
marked_for_deletion.add(index + 1)
for index in marked_for_deletion:
self.done_ranges.pop(index)
def complete_done_ranges(self, data_interval_end_input: str | dt.datetime):
"""Complete the entire range covered by the batch export.
This is meant to be called at the end of a batch export to ensure
`self.done_ranges` covers the entire batch period from whichever was the
first range tracked until `data_interval_end_input`.
All ranges will be essentially merged into one (well, replaced by one)
covering everything, so it is very important to only call this once
everything is done.
"""
if isinstance(data_interval_end_input, str):
data_interval_end = dt.datetime.fromisoformat(data_interval_end_input)
else:
data_interval_end = data_interval_end_input
self.done_ranges = [(self.done_ranges[0][0], data_interval_end)]
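# Illustrative sketch (not part of this commit): how an activity would accumulate
# `done_ranges` while exporting. The `_example_`-prefixed function and its values
# are hypothetical.
def _example_track_and_complete_done_ranges() -> None:
    details = BatchExportRangeHeartbeatDetails()
    data_interval_start = dt.datetime(2023, 7, 31, 12, 0, tzinfo=dt.UTC)
    data_interval_end = dt.datetime(2023, 7, 31, 13, 0, tzinfo=dt.UTC)

    # The first tracked range has its start overridden to `data_interval_start`.
    details.track_done_range(
        (dt.datetime(2023, 7, 31, 12, 10, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, tzinfo=dt.UTC)),
        data_interval_start,
    )
    # The second range touches the first, so the two are merged into one.
    details.track_done_range(
        (dt.datetime(2023, 7, 31, 12, 30, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 45, tzinfo=dt.UTC)),
        data_interval_start,
    )
    assert details.done_ranges == [(data_interval_start, dt.datetime(2023, 7, 31, 12, 45, tzinfo=dt.UTC))]

    # Once the batch period is fully exported, collapse everything into one range.
    details.complete_done_ranges(data_interval_end)
    assert details.done_ranges == [(data_interval_start, data_interval_end)]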
HeartbeatType = typing.TypeVar("HeartbeatType", bound=HeartbeatDetails)
async def should_resume_from_activity_heartbeat(
activity, heartbeat_type: type[HeartbeatType]
) -> tuple[bool, HeartbeatType | None]:
"""Check if a batch export should resume from an activity's heartbeat details.
We understand that a batch export should resume any time that we receive heartbeat details and
those details can be correctly parsed. However, the decision is ultimately up to the batch export
activity to decide if it must resume and how to do so.
Returns:
A tuple with the first element indicating if the batch export should resume. If the first element
is True, the second tuple element will be the heartbeat details themselves, otherwise None.
"""
try:
heartbeat_details = heartbeat_type.from_activity(activity)
except EmptyHeartbeatError:
# We don't log this as it's the expected exception when heartbeat is empty.
heartbeat_details = None
received = False
except NotEnoughHeartbeatValuesError:
heartbeat_details = None
received = False
await logger.awarning("Details from previous activity execution did not contain the expected amount of values")
except HeartbeatParseError:
heartbeat_details = None
received = False
await logger.awarning("Details from previous activity execution could not be parsed.")
except Exception:
# We should start from the beginning, but we make a point to log unexpected errors.
# Ideally, any new exceptions should be added to the previous blocks after the first time and we will never land here.
heartbeat_details = None
received = False
await logger.aexception("Did not receive details from previous activity Execution due to an unexpected error")
else:
received = True
await logger.adebug(
f"Received details from previous activity: {heartbeat_details}",
)
return received, heartbeat_details
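# Illustrative sketch (not part of this commit): round-tripping heartbeat details
# the way `Heartbeater.set_from_heartbeat_details` and a resumed activity would.
# The `_example_`-prefixed function and its values are hypothetical.
def _example_heartbeat_details_round_trip() -> None:
    details = BatchExportRangeHeartbeatDetails(
        done_ranges=[
            (dt.datetime(2023, 7, 31, 12, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, tzinfo=dt.UTC)),
        ],
    )
    # Datetimes are serialized as ISO strings so the details stay JSON-serializable.
    serialized = details.serialize_details()
    # A retried activity reconstructs the details from its heartbeat payload.
    resumed = BatchExportRangeHeartbeatDetails.from_activity_details(serialized)
    assert resumed.done_ranges == details.done_ranges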

View File

@ -47,7 +47,11 @@ from posthog.temporal.batch_exports.utils import (
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
should_resume_from_activity_heartbeat,
)
def remove_escaped_whitespace_recursive(value):
@ -273,7 +277,7 @@ def get_redshift_fields_from_record_schema(
@dataclasses.dataclass
class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails):
class RedshiftHeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""The Redshift batch export details included in every heartbeat."""
pass
@ -285,6 +289,9 @@ async def insert_records_to_redshift(
schema: str | None,
table: str,
heartbeater: Heartbeater,
heartbeat_details: RedshiftHeartbeatDetails,
data_interval_start: dt.datetime | None,
data_interval_end: dt.datetime,
batch_size: int = 100,
use_super: bool = False,
known_super_columns: list[str] | None = None,
@ -352,7 +359,11 @@ async def insert_records_to_redshift(
# the byte size of each batch the way things are currently written. We can revisit this
# in the future if we decide it's useful enough.
batch_start_inserted_at = None
async for record, _inserted_at in records_iterator:
if batch_start_inserted_at is None:
batch_start_inserted_at = _inserted_at
for column in columns:
if known_super_columns is not None and column in known_super_columns:
record[column] = json.dumps(record[column], ensure_ascii=False)
@ -362,12 +373,24 @@ async def insert_records_to_redshift(
continue
await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)
last_date_range = (batch_start_inserted_at, _inserted_at)
heartbeat_details.track_done_range(last_date_range, data_interval_start)
heartbeater.set_from_heartbeat_details(heartbeat_details)
batch_start_inserted_at = None
batch = []
if len(batch) > 0:
if len(batch) > 0 and batch_start_inserted_at:
await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)
last_date_range = (batch_start_inserted_at, _inserted_at)
heartbeat_details.track_done_range(last_date_range, data_interval_start)
heartbeater.set_from_heartbeat_details(heartbeat_details)
heartbeat_details.complete_done_ranges(data_interval_end)
heartbeater.set_from_heartbeat_details(heartbeat_details)
return total_rows_exported
@ -420,12 +443,11 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")
should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger)
_, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails)
if details is None:
details = RedshiftHeartbeatDetails()
if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
else:
data_interval_start = inputs.data_interval_start
done_ranges: list[DateRange] = details.done_ranges
model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
@ -446,13 +468,19 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None
data_interval_start = (
dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None
)
data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end)
full_range = (data_interval_start, data_interval_end)
queue, produce_task = start_produce_batch_export_record_batches(
client=client,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
full_range=full_range,
done_ranges=done_ranges,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
@ -545,7 +573,12 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
# TODO: We should be able to save a json.loads here.
record[column] = remove_escaped_whitespace_recursive(json.loads(record[column]))
return record, row["_inserted_at"]
if isinstance(row["_inserted_at"], int):
inserted_at = dt.datetime.fromtimestamp(row["_inserted_at"])
else:
inserted_at = row["_inserted_at"]
return record, inserted_at
async def record_generator() -> (
collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]
@ -574,6 +607,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
heartbeater=heartbeater,
use_super=properties_type == "SUPER",
known_super_columns=known_super_columns,
heartbeat_details=details,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
)
if requires_merge:

View File

@ -6,6 +6,7 @@ import io
import json
import posixpath
import typing
import collections.abc
import aioboto3
import botocore.exceptions
@ -52,6 +53,12 @@ from posthog.temporal.batch_exports.utils import (
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
HeartbeatParseError,
should_resume_from_activity_heartbeat,
)
def get_allowed_template_variables(inputs) -> dict[str, str]:
@ -379,22 +386,51 @@ class S3MultiPartUpload:
return False
class HeartbeatDetails(typing.NamedTuple):
@dataclasses.dataclass
class S3HeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""This tuple allows us to enforce a schema on the Heartbeat details.
Attributes:
last_uploaded_part_timestamp: The timestamp of the last part we managed to upload.
upload_state: State to continue a S3MultiPartUpload when activity execution resumes.
"""
last_uploaded_part_timestamp: str
upload_state: S3MultiPartUploadState
upload_state: S3MultiPartUploadState | None = None
@classmethod
def from_activity_details(cls, details):
last_uploaded_part_timestamp = details[0]
upload_state = S3MultiPartUploadState(*details[1])
return cls(last_uploaded_part_timestamp, upload_state)
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
upload_state = None
remaining = super().deserialize_details(details)
if len(remaining["_remaining"]) == 0:
return {"upload_state": upload_state, **remaining}
first_detail = remaining["_remaining"][0]
remaining["_remaining"] = remaining["_remaining"][1:]
if first_detail is None:
return {"upload_state": None, **remaining}
try:
upload_state = S3MultiPartUploadState(*first_detail)
except (TypeError, ValueError) as e:
raise HeartbeatParseError("upload_state") from e
return {"upload_state": upload_state, **remaining}
def serialize_details(self) -> tuple[typing.Any, ...]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
serialized_parent_details = super().serialize_details()
return (*serialized_parent_details[:-1], self.upload_state, self._remaining)
def append_upload_state(self, upload_state: S3MultiPartUploadState):
if self.upload_state is None:
self.upload_state = upload_state
current_parts = {part["PartNumber"] for part in self.upload_state.parts}
for part in upload_state.parts:
if part["PartNumber"] not in current_parts:
self.upload_state.parts.append(part)
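# Illustrative sketch (not part of this commit): how `append_upload_state` merges
# part lists across retries. A minimal stand-in replaces `S3MultiPartUploadState`
# here, since only its `parts` attribute is relevant; the function and values are
# hypothetical.
def _example_append_upload_state() -> None:
    @dataclasses.dataclass
    class _FakeUploadState:
        parts: list[dict[str, typing.Any]]

    details = S3HeartbeatDetails()
    details.append_upload_state(_FakeUploadState(parts=[{"PartNumber": 1, "ETag": "etag-1"}]))
    # A later state repeats part 1 and adds part 2; only part 2 is appended.
    details.append_upload_state(
        _FakeUploadState(parts=[{"PartNumber": 1, "ETag": "etag-1"}, {"PartNumber": 2, "ETag": "etag-2"}])
    )
    assert [part["PartNumber"] for part in details.upload_state.parts] == [1, 2]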
@dataclasses.dataclass
@ -428,7 +464,9 @@ class S3InsertInputs:
batch_export_schema: BatchExportSchema | None = None
async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str | None]:
async def initialize_and_resume_multipart_upload(
inputs: S3InsertInputs,
) -> tuple[S3MultiPartUpload, S3HeartbeatDetails]:
"""Initialize a S3MultiPartUpload and resume it from a hearbeat state if available."""
logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="S3")
key = get_s3_key(inputs)
@ -444,34 +482,16 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl
endpoint_url=inputs.endpoint_url,
)
details = activity.info().heartbeat_details
_, details = await should_resume_from_activity_heartbeat(activity, S3HeartbeatDetails)
if details is None:
details = S3HeartbeatDetails()
try:
interval_start, upload_state = HeartbeatDetails.from_activity_details(details)
except IndexError:
# This is the error we expect when no details as the sequence will be empty.
interval_start = inputs.data_interval_start
await logger.adebug(
"Did not receive details from previous activity Execution. Export will start from the beginning %s",
interval_start,
)
except Exception:
# We still start from the beginning, but we make a point to log unexpected errors.
# Ideally, any new exceptions should be added to the previous block after the first time and we will never land here.
interval_start = inputs.data_interval_start
await logger.awarning(
"Did not receive details from previous activity Execution due to an unexpected error. Export will start from the beginning %s",
interval_start,
)
else:
await logger.ainfo(
"Received details from previous activity. Export will attempt to resume from %s",
interval_start,
)
s3_upload.continue_from_state(upload_state)
if details.upload_state:
s3_upload.continue_from_state(details.upload_state)
if inputs.compression == "brotli":
# Even if we receive details we cannot resume a brotli compressed upload as we have lost the compressor state.
# Even if we receive details we cannot resume a brotli compressed upload as
# we have lost the compressor state.
interval_start = inputs.data_interval_start
await logger.ainfo(
@ -480,7 +500,7 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl
)
await s3_upload.abort()
return s3_upload, interval_start
return s3_upload, details
def s3_default_fields() -> list[BatchExportField]:
@ -527,7 +547,14 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")
s3_upload, interval_start = await initialize_and_resume_multipart_upload(inputs)
s3_upload, details = await initialize_and_resume_multipart_upload(inputs)
# TODO: Switch to single-producer multiple consumer
done_ranges: list[DateRange] = details.done_ranges
if done_ranges:
data_interval_start: str | None = done_ranges[-1][1].isoformat()
else:
data_interval_start = inputs.data_interval_start
model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
@ -541,7 +568,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
model=model,
client=client,
team_id=inputs.team_id,
interval_start=interval_start,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
@ -562,13 +589,13 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_inserted_at: dt.datetime,
last_date_range: DateRange,
last: bool,
error: Exception | None,
):
if error is not None:
await logger.adebug("Error while writing part %d", s3_upload.part_number + 1, exc_info=error)
await logger.awarn(
await logger.awarning(
"An error was detected while writing part %d. Partial part will not be uploaded in case it can be retried.",
s3_upload.part_number + 1,
)
@ -587,7 +614,9 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)
heartbeater.details = (str(last_inserted_at), s3_upload.to_state())
details.track_done_range(last_date_range, data_interval_start)
details.append_upload_state(s3_upload.to_state())
heartbeater.set_from_heartbeat_details(details)
first_record_batch = cast_record_batch_json_columns(first_record_batch)
column_names = first_record_batch.column_names
@ -618,6 +647,9 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
await writer.write_record_batch(record_batch)
details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)
records_completed = writer.records_total
await s3_upload.complete()

View File

@ -51,10 +51,10 @@ from posthog.temporal.batch_exports.utils import (
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
HeartbeatParseError,
NotEnoughHeartbeatValuesError,
should_resume_from_activity_heartbeat,
)
@ -90,28 +90,38 @@ class SnowflakeRetryableConnectionError(Exception):
@dataclasses.dataclass
class SnowflakeHeartbeatDetails(BatchExportHeartbeatDetails):
class SnowflakeHeartbeatDetails(BatchExportRangeHeartbeatDetails):
"""The Snowflake batch export details included in every heartbeat.
Attributes:
file_no: The file number of the last file we managed to upload.
"""
file_no: int
file_no: int = 0
@classmethod
def from_activity(cls, activity):
details = BatchExportHeartbeatDetails.from_activity(activity)
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
file_no = 0
remaining = super().deserialize_details(details)
if details.total_details < 2:
raise NotEnoughHeartbeatValuesError(details.total_details, 2)
if len(remaining["_remaining"]) == 0:
return {"file_no": 0, **remaining}
first_detail = remaining["_remaining"][0]
remaining["_remaining"] = remaining["_remaining"][1:]
try:
file_no = int(details._remaining[0])
file_no = int(first_detail)
except (TypeError, ValueError) as e:
raise HeartbeatParseError("file_no") from e
return cls(last_inserted_at=details.last_inserted_at, file_no=file_no, _remaining=details._remaining[2:])
return {"file_no": file_no, **remaining}
def serialize_details(self) -> tuple[typing.Any, ...]:
"""Attempt to initialize HeartbeatDetails from an activity's details."""
serialized_parent_details = super().serialize_details()
return (*serialized_parent_details[:-1], self.file_no, self._remaining)
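# Illustrative sketch (not part of this commit): the serialization order used by
# these details, with `file_no` placed between the done ranges and `_remaining`.
# The `_example_`-prefixed function and its values are hypothetical.
def _example_snowflake_heartbeat_round_trip() -> None:
    details = SnowflakeHeartbeatDetails(
        done_ranges=[
            (dt.datetime(2023, 7, 31, 12, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, tzinfo=dt.UTC)),
        ],
        file_no=3,
    )
    resumed = SnowflakeHeartbeatDetails.from_activity_details(details.serialize_details())
    assert resumed.done_ranges == details.done_ranges
    assert resumed.file_no == 3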
@dataclasses.dataclass
@ -579,16 +589,17 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")
should_resume, details = await should_resume_from_activity_heartbeat(
activity, SnowflakeHeartbeatDetails, logger
)
_, details = await should_resume_from_activity_heartbeat(activity, SnowflakeHeartbeatDetails)
if details is None:
details = SnowflakeHeartbeatDetails()
if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
current_flush_counter = details.file_no
done_ranges: list[DateRange] = details.done_ranges
if done_ranges:
data_interval_start: str | None = done_ranges[-1][1].isoformat()
else:
data_interval_start = inputs.data_interval_start
current_flush_counter = 0
current_flush_counter = details.file_no
rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()
@ -670,7 +681,7 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
records_since_last_flush,
bytes_since_last_flush,
flush_counter: int,
last_inserted_at,
last_date_range: DateRange,
last: bool,
error: Exception | None,
):
@ -690,7 +701,9 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)
heartbeater.details = (str(last_inserted_at), flush_counter)
details.track_done_range(last_date_range, data_interval_start)
details.file_no = flush_counter
heartbeater.set_from_heartbeat_details(details)
writer = JSONLBatchExportWriter(
max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES,
@ -703,6 +716,9 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
await writer.write_record_batch(record_batch)
details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)
await snow_client.copy_loaded_files_to_snowflake_table(
snow_stage_table if requires_merge else snow_table, data_interval_end_str
)

View File

@ -17,6 +17,8 @@ import pyarrow as pa
import pyarrow.parquet as pq
import structlog
from posthog.temporal.batch_exports.heartbeat import DateRange
logger = structlog.get_logger()
@ -247,7 +249,6 @@ class BatchExportTemporaryFile:
self.records_since_last_reset = 0
LastInsertedAt = dt.datetime
IsLast = bool
RecordsSinceLastFlush = int
BytesSinceLastFlush = int
@ -258,7 +259,7 @@ FlushCallable = collections.abc.Callable[
RecordsSinceLastFlush,
BytesSinceLastFlush,
FlushCounter,
LastInsertedAt,
DateRange,
IsLast,
Exception | None,
],
@ -318,7 +319,9 @@ class BatchExportWriter(abc.ABC):
def reset_writer_tracking(self):
"""Reset this writer's tracking state."""
self.last_inserted_at: dt.datetime | None = None
self.start_at_since_last_flush: dt.datetime | None = None
self.end_at_since_last_flush: dt.datetime | None = None
self.flushed_date_ranges: list[DateRange] = []
self.records_total = 0
self.records_since_last_flush = 0
self.bytes_total = 0
@ -326,6 +329,13 @@ class BatchExportWriter(abc.ABC):
self.flush_counter = 0
self.error = None
@property
def date_range_since_last_flush(self) -> DateRange | None:
if self.start_at_since_last_flush is not None and self.end_at_since_last_flush is not None:
return (self.start_at_since_last_flush, self.end_at_since_last_flush)
else:
return None
@contextlib.asynccontextmanager
async def open_temporary_file(self, current_flush_counter: int = 0):
"""Explicitly open the temporary file this writer is writing to.
@ -352,12 +362,12 @@ class BatchExportWriter(abc.ABC):
finally:
self.track_bytes_written(temp_file)
if self.last_inserted_at is not None and self.bytes_since_last_flush > 0:
if self.bytes_since_last_flush > 0:
# `bytes_since_last_flush` should be 0 unless:
# 1. The last batch wasn't flushed as it didn't reach `max_bytes`.
# 2. The last batch was flushed but there was another write after the last call to
# `write_record_batch`. For example, footer bytes.
await self.flush(self.last_inserted_at, is_last=True)
await self.flush(is_last=True)
self._batch_export_file = None
@ -394,24 +404,38 @@ class BatchExportWriter(abc.ABC):
async def write_record_batch(self, record_batch: pa.RecordBatch, flush: bool = True) -> None:
"""Issue a record batch write tracking progress and flushing if required."""
record_batch = record_batch.sort_by("_inserted_at")
last_inserted_at = record_batch.column("_inserted_at")[-1].as_py()
if self.start_at_since_last_flush is None:
raw_start_at = record_batch.column("_inserted_at")[0].as_py()
if isinstance(raw_start_at, int):
try:
self.start_at_since_last_flush = dt.datetime.fromtimestamp(raw_start_at, tz=dt.UTC)
except Exception:
raise
else:
self.start_at_since_last_flush = raw_start_at
raw_end_at = record_batch.column("_inserted_at")[-1].as_py()
if isinstance(raw_end_at, int):
self.end_at_since_last_flush = dt.datetime.fromtimestamp(raw_end_at, tz=dt.UTC)
else:
self.end_at_since_last_flush = raw_end_at
column_names = record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))
await asyncio.to_thread(self._write_record_batch, record_batch.select(column_names))
self.last_inserted_at = last_inserted_at
self.track_records_written(record_batch)
self.track_bytes_written(self.batch_export_file)
if flush and self.should_flush():
await self.flush(last_inserted_at)
await self.flush()
def should_flush(self) -> bool:
return self.bytes_since_last_flush >= self.max_bytes
async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None:
async def flush(self, is_last: bool = False) -> None:
"""Call the provided `flush_callable` and reset underlying file.
The underlying batch export temporary file will be reset after calling `flush_callable`.
@ -421,12 +445,15 @@ class BatchExportWriter(abc.ABC):
self.batch_export_file.seek(0)
if self.date_range_since_last_flush is not None:
self.flushed_date_ranges.append(self.date_range_since_last_flush)
await self.flush_callable(
self.batch_export_file,
self.records_since_last_flush,
self.bytes_since_last_flush,
self.flush_counter,
last_inserted_at,
self.flushed_date_ranges[-1],
is_last,
self.error,
)
@ -435,6 +462,8 @@ class BatchExportWriter(abc.ABC):
self.records_since_last_flush = 0
self.bytes_since_last_flush = 0
self.flush_counter += 1
self.start_at_since_last_flush = None
self.end_at_since_last_flush = None
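# Illustrative sketch (not part of this commit): the shape of a flush callable after
# this change, receiving the flushed date range instead of a single `last_inserted_at`.
# `_example_flush` is hypothetical; the destination-specific flush callables shown
# earlier in this diff follow this signature.
async def _example_flush(
    batch_export_file: BatchExportTemporaryFile,
    records_since_last_flush: int,
    bytes_since_last_flush: int,
    flush_counter: int,
    last_date_range: DateRange,
    last: bool,
    error: Exception | None,
) -> None:
    # A real callable would upload `batch_export_file` here, then track
    # `last_date_range` in its heartbeat details before heartbeating.
    return None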
class JSONLBatchExportWriter(BatchExportWriter):

View File

@ -1,5 +1,8 @@
import asyncio
import typing
import dataclasses
import collections.abc
import abc
from temporalio import activity
@ -20,7 +23,7 @@ class Heartbeater:
maintained while in the context manager to avoid garbage collection.
"""
def __init__(self, details: tuple[typing.Any, ...] = (), factor: int = 12):
def __init__(self, details: tuple[typing.Any, ...] = (), factor: int = 120):
self._details: tuple[typing.Any, ...] = details
self.factor = factor
self.heartbeat_task: asyncio.Task | None = None
@ -36,6 +39,10 @@ class Heartbeater:
"""Set tuple to be passed as heartbeat details."""
self._details = details
def set_from_heartbeat_details(self, details: "HeartbeatDetails") -> None:
"""Set `HeartbeatDetails` to be passed as heartbeat details."""
self._details = tuple(details.serialize_details())
async def __aenter__(self):
"""Enter managed heartbeatting context."""
@ -82,3 +89,116 @@ class Heartbeater:
self.heartbeat_task = None
self.heartbeat_on_shutdown_task = None
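# Illustrative sketch (not part of this commit): how an activity pairs `Heartbeater`
# with heartbeat details. `_example_activity_step` is hypothetical; the batch export
# activities in this diff follow the same pattern.
async def _example_activity_step(details: "HeartbeatDetails") -> None:
    async with Heartbeater() as heartbeater:
        # ... export some data and update `details` accordingly ...
        # Serialize the details so the background heartbeat task reports them.
        heartbeater.set_from_heartbeat_details(details)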
class EmptyHeartbeatError(Exception):
"""Raised when an activity heartbeat is empty.
This is also the error we expect when no heartbeating is happening, as the sequence will be empty.
"""
def __init__(self):
super().__init__(f"Heartbeat details sequence is empty")
class NotEnoughHeartbeatValuesError(Exception):
"""Raised when an activity heartbeat doesn't contain the right amount of values we expect."""
def __init__(self, details_len: int, expected: int):
super().__init__(f"Not enough values in heartbeat details (expected {expected}, got {details_len})")
class HeartbeatParseError(Exception):
"""Raised when an activity heartbeat cannot be parsed into it's expected types."""
def __init__(self, field: str):
super().__init__(f"Parsing {field} from heartbeat details encountered an error")
@dataclasses.dataclass
class HeartbeatDetails(metaclass=abc.ABCMeta):
"""Details included in every heartbeat.
If an activity requires tracking progress, this should be subclassed to include
the attributes that are required for said activity. The main methods to implement
when subclassing are `deserialize_details` and `serialize_details`. Both should
deserialize from and serialize to a generic sequence or tuple, respectively.
Attributes:
_remaining: Any remaining values in the heartbeat_details tuple that we do
not parse.
"""
_remaining: collections.abc.Sequence[typing.Any]
@property
def total_details(self) -> int:
"""The total number of details that we have parsed + those remaining to parse."""
return (len(dataclasses.fields(self.__class__)) - 1) + len(self._remaining)
@classmethod
@abc.abstractmethod
def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
"""Deserialize `HeartbeatDetails` from a generic sequence of details.
This base class implementation just returns all details as `_remaining`.
Subclasses first call this method, and then peek into `_remaining` and
extract the details they need. For now, subclasses can only rely on the
order in which details are serialized but in the future we may need a
more robust way of identifying details.
Arguments:
details: A collection of details as returned by
`temporalio.activity.info().heartbeat_details`
"""
return {"_remaining": details}
@abc.abstractmethod
def serialize_details(self) -> tuple[typing.Any, ...]:
"""Serialize `HeartbeatDetails` to a tuple.
Since subclasses rely on the order details are serialized, subclasses
should be careful here to maintain a consistent serialization order. For
example, `_remaining` should always be placed last.
Returns:
A tuple of serialized details.
"""
return (self._remaining,)
@classmethod
def from_activity(cls, activity):
"""Instantiate this class from a Temporal Activity."""
details = activity.info().heartbeat_details
return cls.from_activity_details(details)
@classmethod
def from_activity_details(cls, details):
parsed = cls.deserialize_details(details)
return cls(**parsed)
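# Illustrative sketch (not part of this commit): a minimal subclass following the
# pattern described in `deserialize_details` and `serialize_details`. The
# `_ExampleCursorHeartbeatDetails` class and its `cursor` field are hypothetical.
@dataclasses.dataclass
class _ExampleCursorHeartbeatDetails(HeartbeatDetails):
    cursor: str = ""

    @classmethod
    def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]:
        # Claim the first remaining value as our `cursor`; leave the rest untouched.
        remaining = super().deserialize_details(details)
        if len(remaining["_remaining"]) == 0:
            return {"cursor": "", **remaining}
        cursor = str(remaining["_remaining"][0])
        remaining["_remaining"] = remaining["_remaining"][1:]
        return {"cursor": cursor, **remaining}

    def serialize_details(self) -> tuple[typing.Any, ...]:
        # Keep `_remaining` last so details we do not understand stay at the end.
        serialized_parent_details = super().serialize_details()
        return (*serialized_parent_details[:-1], self.cursor, self._remaining)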
@dataclasses.dataclass
class DataImportHeartbeatDetails(HeartbeatDetails):
"""Data import heartbeat details.
Attributes:
endpoint: The endpoint we are importing data from.
cursor: The cursor we are using to paginate through the endpoint.
"""
endpoint: str
cursor: str
@classmethod
def from_activity(cls, activity):
"""Attempt to initialize DataImportHeartbeatDetails from an activity's info."""
details = activity.info().heartbeat_details
if len(details) == 0:
raise EmptyHeartbeatError()
if len(details) != 2:
raise NotEnoughHeartbeatValuesError(len(details), 2)
return cls(endpoint=details[0], cursor=details[1], _remaining=details[2:])

View File

@ -1,149 +0,0 @@
import abc
import collections.abc
import dataclasses
import datetime as dt
import typing
class EmptyHeartbeatError(Exception):
"""Raised when an activity heartbeat is empty.
This is also the error we expect when no heartbeatting is happening, as the sequence will be empty.
"""
def __init__(self):
super().__init__(f"Heartbeat details sequence is empty")
class NotEnoughHeartbeatValuesError(Exception):
"""Raised when an activity heartbeat doesn't contain the right amount of values we expect."""
def __init__(self, details_len: int, expected: int):
super().__init__(f"Not enough values in heartbeat details (expected {expected}, got {details_len})")
class HeartbeatParseError(Exception):
"""Raised when an activity heartbeat cannot be parsed into it's expected types."""
def __init__(self, field: str):
super().__init__(f"Parsing {field} from heartbeat details encountered an error")
@dataclasses.dataclass
class HeartbeatDetails(metaclass=abc.ABCMeta):
"""The batch export details included in every heartbeat.
Each batch export destination should subclass this and implement whatever details are specific to that
batch export and required to resume it.
Attributes:
last_inserted_at: The last inserted_at we managed to upload or insert, depending on the destination.
_remaining: Any remaining values in the heartbeat_details tuple that we do not parse.
"""
_remaining: collections.abc.Sequence[typing.Any]
@property
def total_details(self) -> int:
"""The total number of details that we have parsed + those remaining to parse."""
return (len(dataclasses.fields(self.__class__)) - 1) + len(self._remaining)
@abc.abstractclassmethod
def from_activity(cls, activity):
pass
@dataclasses.dataclass
class BatchExportHeartbeatDetails(HeartbeatDetails):
last_inserted_at: dt.datetime
@classmethod
def from_activity(cls, activity):
"""Attempt to initialize HeartbeatDetails from an activity's info."""
details = activity.info().heartbeat_details
if len(details) == 0:
raise EmptyHeartbeatError()
try:
last_inserted_at = dt.datetime.fromisoformat(details[0])
except (TypeError, ValueError) as e:
raise HeartbeatParseError("last_inserted_at") from e
return cls(last_inserted_at=last_inserted_at, _remaining=details[1:])
@dataclasses.dataclass
class DataImportHeartbeatDetails(HeartbeatDetails):
"""Data import heartbeat details.
Attributes:
endpoint: The endpoint we are importing data from.
cursor: The cursor we are using to paginate through the endpoint.
"""
endpoint: str
cursor: str
@classmethod
def from_activity(cls, activity):
"""Attempt to initialize DataImportHeartbeatDetails from an activity's info."""
details = activity.info().heartbeat_details
if len(details) == 0:
raise EmptyHeartbeatError()
if len(details) != 2:
raise NotEnoughHeartbeatValuesError(len(details), 2)
return cls(endpoint=details[0], cursor=details[1], _remaining=details[2:])
HeartbeatType = typing.TypeVar("HeartbeatType", bound=HeartbeatDetails)
async def should_resume_from_activity_heartbeat(
activity, heartbeat_type: type[HeartbeatType], logger
) -> tuple[bool, HeartbeatType | None]:
"""Check if a batch export should resume from an activity's heartbeat details.
We understand that a batch export should resume any time that we receive heartbeat details and
those details can be correctly parsed. However, the decision is ultimately up to the batch export
activity to decide if it must resume and how to do so.
Returns:
A tuple with the first element indicating if the batch export should resume. If the first element
is True, the second tuple element will be the heartbeat details themselves, otherwise None.
"""
try:
heartbeat_details = heartbeat_type.from_activity(activity)
except EmptyHeartbeatError:
# We don't log this as it's the expected exception when heartbeat is empty.
heartbeat_details = None
received = False
except NotEnoughHeartbeatValuesError:
heartbeat_details = None
received = False
await logger.awarning("Details from previous activity execution did not contain the expected amount of values")
except HeartbeatParseError:
heartbeat_details = None
received = False
await logger.awarning("Details from previous activity execution could not be parsed.")
except Exception:
# We should start from the beginning, but we make a point to log unexpected errors.
# Ideally, any new exceptions should be added to the previous blocks after the first time and we will never land here.
heartbeat_details = None
received = False
await logger.aexception("Did not receive details from previous activity Execution due to an unexpected error")
else:
received = True
await logger.adebug(
f"Received details from previous activity: {heartbeat_details}",
)
return received, heartbeat_details

View File

@ -13,6 +13,7 @@ from posthog.temporal.batch_exports.batch_exports import (
RecordBatchProducerError,
RecordBatchQueue,
TaskNotDoneError,
generate_query_ranges,
get_data_interval,
iter_model_records,
iter_records,
@ -463,8 +464,8 @@ async def test_start_produce_batch_export_record_batches_uses_extra_query_parame
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
fields=[
{"expression": "JSONExtractInt(properties, %(hogql_val_0)s)", "alias": "custom_prop"},
],
@ -503,8 +504,8 @@ async def test_start_produce_batch_export_record_batches_can_flatten_properties(
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
fields=[
{"expression": "event", "alias": "event"},
{"expression": "JSONExtractString(properties, '$browser')", "alias": "browser"},
@ -560,8 +561,8 @@ async def test_start_produce_batch_export_record_batches_with_single_field_and_a
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
fields=[field],
extra_query_parameters={},
)
@ -615,8 +616,8 @@ async def test_start_produce_batch_export_record_batches_ignores_timestamp_predi
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=inserted_at.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(inserted_at, data_interval_end),
done_ranges=[],
)
records = await get_all_record_batches_from_queue(queue, produce_task)
@ -629,8 +630,8 @@ async def test_start_produce_batch_export_record_batches_ignores_timestamp_predi
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=inserted_at.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(inserted_at, data_interval_end),
done_ranges=[],
)
records = await get_all_record_batches_from_queue(queue, produce_task)
@ -664,8 +665,8 @@ async def test_start_produce_batch_export_record_batches_can_include_events(clic
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
include_events=include_events,
)
@ -700,8 +701,8 @@ async def test_start_produce_batch_export_record_batches_can_exclude_events(clic
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
exclude_events=exclude_events,
)
@ -733,8 +734,8 @@ async def test_start_produce_batch_export_record_batches_handles_duplicates(clic
team_id=team_id,
is_backfill=False,
model_name="events",
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
full_range=(data_interval_start, data_interval_end),
done_ranges=[],
)
records = await get_all_record_batches_from_queue(queue, produce_task)
@ -833,3 +834,119 @@ async def test_raise_on_produce_task_failure_does_not_raise():
await asyncio.wait([task])
await raise_on_produce_task_failure(task)
@pytest.mark.parametrize(
"remaining_range,done_ranges,expected",
[
# Case 1: One done range at the beginning
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC))],
[
(
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC),
)
],
),
# Case 2: Single done range equal to full range.
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC))],
[],
),
# Case 3: Disconnected done ranges cover full range.
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC)),
(
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC),
),
],
[],
),
# Case 4: Disconnected done ranges within full range.
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[
(
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 55, 0, tzinfo=dt.UTC),
),
],
[
(
dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 55, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC),
),
],
),
# Case 5: Empty done ranges.
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[],
[
(
dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC),
),
],
),
# Case 6: Disconnected done ranges within full range and one last done range connected to the end.
(
(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)),
[
(
dt.datetime(2023, 7, 31, 12, 15, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 25, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC),
),
],
[
(
dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 15, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 25, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC),
),
(
dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC),
dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC),
),
],
),
],
ids=["1", "2", "3", "4", "5", "6"],
)
def test_generate_query_ranges(remaining_range, done_ranges, expected):
"""Test get_data_interval returns the expected data interval tuple."""
result = list(generate_query_ranges(remaining_range, done_ranges))
assert result == expected

View File

@ -27,6 +27,7 @@ from posthog.temporal.batch_exports.batch_exports import (
)
from posthog.temporal.batch_exports.bigquery_batch_export import (
BigQueryBatchExportWorkflow,
BigQueryHeartbeatDetails,
BigQueryInsertInputs,
bigquery_default_fields,
get_bigquery_fields_from_record_schema,
@ -54,14 +55,19 @@ pytestmark = [SKIP_IF_MISSING_GOOGLE_APPLICATION_CREDENTIALS, pytest.mark.asynci
TEST_TIME = dt.datetime.now(dt.UTC)
@pytest.fixture
def activity_environment(activity_environment):
activity_environment.heartbeat_class = BigQueryHeartbeatDetails
return activity_environment
async def assert_clickhouse_records_in_bigquery(
bigquery_client: bigquery.Client,
clickhouse_client: ClickHouseClient,
team_id: int,
table_id: str,
dataset_id: str,
data_interval_start: dt.datetime,
data_interval_end: dt.datetime,
date_ranges: list[tuple[dt.datetime, dt.datetime]],
min_ingested_timestamp: dt.datetime | None = None,
exclude_events: list[str] | None = None,
include_events: list[str] | None = None,
@ -69,6 +75,7 @@ async def assert_clickhouse_records_in_bigquery(
use_json_type: bool = False,
sort_key: str = "event",
is_backfill: bool = False,
expect_duplicates: bool = False,
) -> None:
"""Assert ClickHouse records are written to a given BigQuery table.
@ -78,13 +85,13 @@ async def assert_clickhouse_records_in_bigquery(
team_id: The ID of the team that we are testing for.
table_id: BigQuery table id where records are exported to.
dataset_id: BigQuery dataset containing the table where records are exported to.
data_interval_start: Start of the batch period for exported records.
data_interval_end: End of the batch period for exported records.
date_ranges: Ranges of records we should expect to have been exported.
min_ingested_timestamp: A datetime used to assert a minimum bound for 'bq_ingested_timestamp'.
exclude_events: Event names to be excluded from the export.
include_events: Event names to be included in the export.
batch_export_schema: Custom schema used in the batch export.
use_json_type: Whether to use JSON type for known fields.
expect_duplicates: Whether duplicates are expected (e.g. when testing retrying logic).
"""
if use_json_type is True:
json_columns = ["properties", "set", "set_once", "person_properties"]
@ -135,34 +142,49 @@ async def assert_clickhouse_records_in_bigquery(
]
expected_records = []
async for record_batch in iter_model_records(
client=clickhouse_client,
model=batch_export_model,
team_id=team_id,
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
include_events=include_events,
destination_default_fields=bigquery_default_fields(),
is_backfill=is_backfill,
):
for record in record_batch.select(schema_column_names).to_pylist():
expected_record = {}
for data_interval_start, data_interval_end in date_ranges:
async for record_batch in iter_model_records(
client=clickhouse_client,
model=batch_export_model,
team_id=team_id,
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
include_events=include_events,
destination_default_fields=bigquery_default_fields(),
is_backfill=is_backfill,
):
for record in record_batch.select(schema_column_names).to_pylist():
expected_record = {}
for k, v in record.items():
if k not in schema_column_names or k == "_inserted_at" or k == "bq_ingested_timestamp":
# _inserted_at is not exported, only used for tracking progress.
# bq_ingested_timestamp cannot be compared as it comes from an unstable function.
continue
for k, v in record.items():
if k not in schema_column_names or k == "_inserted_at" or k == "bq_ingested_timestamp":
# _inserted_at is not exported, only used for tracking progress.
# bq_ingested_timestamp cannot be compared as it comes from an unstable function.
continue
if k in json_columns and v is not None:
expected_record[k] = json.loads(v)
elif isinstance(v, dt.datetime):
expected_record[k] = v.replace(tzinfo=dt.UTC)
else:
expected_record[k] = v
if k in json_columns and v is not None:
expected_record[k] = json.loads(v)
elif isinstance(v, dt.datetime):
expected_record[k] = v.replace(tzinfo=dt.UTC)
else:
expected_record[k] = v
expected_records.append(expected_record)
expected_records.append(expected_record)
if expect_duplicates:
seen = set()
def is_record_seen(record) -> bool:
nonlocal seen
if record["uuid"] in seen:
return True
seen.add(record["uuid"])
return False
inserted_records = [record for record in inserted_records if not is_record_seen(record)]
assert len(inserted_records) == len(expected_records)
@ -328,8 +350,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
exclude_events=exclude_events,
include_events=None,
batch_export_model=model,
@ -382,8 +403,7 @@ async def test_insert_into_bigquery_activity_merges_data_in_follow_up_runs(
table_id=f"test_insert_activity_mutability_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=model,
min_ingested_timestamp=ingested_timestamp,
sort_key="person_id",
@ -423,14 +443,235 @@ async def test_insert_into_bigquery_activity_merges_data_in_follow_up_runs(
table_id=f"test_insert_activity_mutability_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=model,
min_ingested_timestamp=ingested_timestamp,
sort_key="person_id",
)
@pytest.mark.parametrize("interval", ["hour"], indirect=True)
@pytest.mark.parametrize(
"done_relative_ranges,expected_relative_ranges",
[
(
[(dt.timedelta(minutes=0), dt.timedelta(minutes=15))],
[(dt.timedelta(minutes=15), dt.timedelta(minutes=60))],
),
(
[
(dt.timedelta(minutes=10), dt.timedelta(minutes=15)),
(dt.timedelta(minutes=35), dt.timedelta(minutes=45)),
],
[
(dt.timedelta(minutes=0), dt.timedelta(minutes=10)),
(dt.timedelta(minutes=15), dt.timedelta(minutes=35)),
(dt.timedelta(minutes=45), dt.timedelta(minutes=60)),
],
),
(
[
(dt.timedelta(minutes=45), dt.timedelta(minutes=60)),
],
[
(dt.timedelta(minutes=0), dt.timedelta(minutes=45)),
],
),
],
)
async def test_insert_into_bigquery_activity_resumes_from_heartbeat(
clickhouse_client,
activity_environment,
bigquery_client,
bigquery_config,
bigquery_dataset,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
done_relative_ranges,
expected_relative_ranges,
):
"""Test we insert partial data into a BigQuery table when resuming.
After an activity runs, heartbeats, and crashes, a follow-up activity should
    pick up from where the first one left off. This capability is critical to ensure
long-running activities that export a lot of data will eventually finish.
"""
batch_export_model = BatchExportModel(name="events", schema=None)
insert_inputs = BigQueryInsertInputs(
team_id=ateam.pk,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
use_json_type=True,
batch_export_model=batch_export_model,
**bigquery_config,
)
now = dt.datetime.now(tz=dt.UTC)
done_ranges = [
(
(data_interval_start + done_relative_range[0]).isoformat(),
(data_interval_start + done_relative_range[1]).isoformat(),
)
for done_relative_range in done_relative_ranges
]
expected_ranges = [
(
(data_interval_start + expected_relative_range[0]),
(data_interval_start + expected_relative_range[1]),
)
for expected_relative_range in expected_relative_ranges
]
workflow_id = uuid.uuid4()
fake_info = activity.Info(
activity_id="insert-into-bigquery-activity",
activity_type="unknown",
current_attempt_scheduled_time=dt.datetime.now(dt.UTC),
workflow_id=str(workflow_id),
workflow_type="bigquery-export",
workflow_run_id=str(uuid.uuid4()),
attempt=1,
heartbeat_timeout=dt.timedelta(seconds=1),
heartbeat_details=[done_ranges],
is_local=False,
schedule_to_close_timeout=dt.timedelta(seconds=10),
scheduled_time=dt.datetime.now(dt.UTC),
start_to_close_timeout=dt.timedelta(seconds=20),
started_time=dt.datetime.now(dt.UTC),
task_queue="test",
task_token=b"test",
workflow_namespace="default",
)
activity_environment.info = fake_info
await activity_environment.run(insert_into_bigquery_activity, insert_inputs)
await assert_clickhouse_records_in_bigquery(
bigquery_client=bigquery_client,
clickhouse_client=clickhouse_client,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
date_ranges=expected_ranges,
include_events=None,
batch_export_model=batch_export_model,
use_json_type=True,
min_ingested_timestamp=now,
sort_key="event",
)
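The fake `activity.Info` above feeds previously heartbeated done ranges back into the activity. A minimal sketch, assuming the Temporal Python SDK and the detail layout used in this test (the first heartbeat detail is a list of ISO-formatted (start, end) pairs), of how a resuming activity could decode them:
import datetime as dt

from temporalio import activity


def done_ranges_from_heartbeat_details() -> list[tuple[dt.datetime, dt.datetime]]:
    """Decode done ranges from the previous attempt's heartbeat details.

    Must run inside an activity; returns an empty list on a fresh first attempt.
    """
    details = activity.info().heartbeat_details
    if not details:
        return []
    return [
        (dt.datetime.fromisoformat(start), dt.datetime.fromisoformat(end))
        for start, end in details[0]
    ]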
async def test_insert_into_bigquery_activity_completes_range(
clickhouse_client,
activity_environment,
bigquery_client,
bigquery_config,
bigquery_dataset,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
):
"""Test we complete a full range of data into a BigQuery table when resuming.
We run two activities:
1. First activity, up to (and including) the cutoff event.
2. Second activity with a heartbeat detail matching the cutoff event.
This simulates the batch export resuming from a failed execution. The full range
should be completed (with a duplicate on the cutoff event) after both activities
are done.
"""
batch_export_model = BatchExportModel(name="events", schema=None)
now = dt.datetime.now(tz=dt.UTC)
events_to_export_created, _ = generate_test_data
events_to_export_created.sort(key=operator.itemgetter("inserted_at"))
cutoff_event = events_to_export_created[len(events_to_export_created) // 2 : len(events_to_export_created) // 2 + 1]
assert len(cutoff_event) == 1
cutoff_event = cutoff_event[0]
cutoff_data_interval_end = dt.datetime.fromisoformat(cutoff_event["inserted_at"]).replace(tzinfo=dt.UTC)
insert_inputs = BigQueryInsertInputs(
team_id=ateam.pk,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
data_interval_start=data_interval_start.isoformat(),
# The extra second is because the upper range select is exclusive and
# we want cutoff to be the last event included.
data_interval_end=(cutoff_data_interval_end + dt.timedelta(seconds=1)).isoformat(),
use_json_type=True,
batch_export_model=batch_export_model,
**bigquery_config,
)
await activity_environment.run(insert_into_bigquery_activity, insert_inputs)
done_ranges = [
(
data_interval_start.isoformat(),
cutoff_data_interval_end.isoformat(),
),
]
workflow_id = uuid.uuid4()
fake_info = activity.Info(
activity_id="insert-into-bigquery-activity",
activity_type="unknown",
current_attempt_scheduled_time=dt.datetime.now(dt.UTC),
workflow_id=str(workflow_id),
workflow_type="bigquery-export",
workflow_run_id=str(uuid.uuid4()),
attempt=1,
heartbeat_timeout=dt.timedelta(seconds=1),
heartbeat_details=[done_ranges],
is_local=False,
schedule_to_close_timeout=dt.timedelta(seconds=10),
scheduled_time=dt.datetime.now(dt.UTC),
start_to_close_timeout=dt.timedelta(seconds=20),
started_time=dt.datetime.now(dt.UTC),
task_queue="test",
task_token=b"test",
workflow_namespace="default",
)
activity_environment.info = fake_info
insert_inputs = BigQueryInsertInputs(
team_id=ateam.pk,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
use_json_type=True,
batch_export_model=batch_export_model,
**bigquery_config,
)
await activity_environment.run(insert_into_bigquery_activity, insert_inputs)
await assert_clickhouse_records_in_bigquery(
bigquery_client=bigquery_client,
clickhouse_client=clickhouse_client,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
date_ranges=[(data_interval_start, data_interval_end)],
include_events=None,
batch_export_model=batch_export_model,
use_json_type=True,
min_ingested_timestamp=now,
sort_key="event",
expect_duplicates=True,
)
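For context on the cutoff arithmetic above: the export query's upper bound is exclusive, so the first run adds a second to include the cutoff event, while the resumed run's done range ends exactly at the cutoff and therefore re-exports that one event. A small illustrative example (dates are made up):
import datetime as dt

cutoff = dt.datetime(2023, 4, 25, 14, 15, 0, tzinfo=dt.UTC)

# First run: upper bound is exclusive, so add a second to include the cutoff event.
first_run_end = cutoff + dt.timedelta(seconds=1)

# Resumed run: the done range recorded in heartbeat details ends at the cutoff,
# so the second run queries from the cutoff onwards and exports that event again.
done_range = (dt.datetime(2023, 4, 25, 14, 0, 0, tzinfo=dt.UTC), cutoff)

# Hence the assertion above sets expect_duplicates=True.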
@pytest.fixture
def table_id(ateam, interval):
return f"test_workflow_table_{ateam.pk}_{interval}"
@ -532,7 +773,7 @@ async def test_bigquery_export_workflow(
id=workflow_id,
task_queue=BATCH_EXPORTS_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
execution_timeout=dt.timedelta(seconds=30),
)
runs = await afetch_batch_export_runs(batch_export_id=bigquery_batch_export.id)
@ -552,8 +793,7 @@ async def test_bigquery_export_workflow(
table_id=table_id,
dataset_id=bigquery_batch_export.destination.config["dataset_id"],
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
exclude_events=exclude_events,
include_events=None,
batch_export_model=model,
@ -715,8 +955,7 @@ async def test_bigquery_export_workflow_backfill_earliest_persons(
table_id=table_id,
dataset_id=bigquery_batch_export.destination.config["dataset_id"],
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=model,
use_json_type=use_json_type,
sort_key="person_id",
@ -759,6 +998,7 @@ async def test_bigquery_export_workflow_handles_insert_activity_errors(ateam, bi
id=workflow_id,
task_queue=BATCH_EXPORTS_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=20),
)
runs = await afetch_batch_export_runs(batch_export_id=bigquery_batch_export.id)

View File

@ -0,0 +1,104 @@
import datetime as dt
import pytest
from posthog.temporal.batch_exports.heartbeat import BatchExportRangeHeartbeatDetails
@pytest.mark.parametrize(
"initial_done_ranges,done_range,expected_index",
[
# Case 1: Inserting into an empty initial list.
([], (dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(6)), 0),
# Case 2: Inserting into middle of initial list.
(
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)),
],
(dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(6)),
1,
),
# Case 3: Inserting into beginning of initial list.
(
[
(dt.datetime.fromtimestamp(1), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)),
],
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(1)),
0,
),
# Case 4: Inserting into end of initial list.
(
[(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10))],
(dt.datetime.fromtimestamp(10), dt.datetime.fromtimestamp(11)),
1,
),
# Case 5: Inserting disconnected range into middle of initial list.
(
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10)),
(dt.datetime.fromtimestamp(15), dt.datetime.fromtimestamp(20)),
],
(dt.datetime.fromtimestamp(12), dt.datetime.fromtimestamp(13)),
1,
),
],
)
def test_insert_done_range(initial_done_ranges, done_range, expected_index):
"""Test `BatchExportRangeHeartbeatDetails` inserts a done range in the expected index.
We avoid merging ranges to maintain the original index so we can assert it matches
the expected index.
"""
heartbeat_details = BatchExportRangeHeartbeatDetails()
heartbeat_details.done_ranges.extend(initial_done_ranges)
heartbeat_details.insert_done_range(done_range, merge=False)
assert len(heartbeat_details.done_ranges) == len(initial_done_ranges) + 1
assert heartbeat_details.done_ranges.index(done_range) == expected_index
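A minimal sketch of sorted insertion by range start, which is the behaviour the cases above assert; this is an assumption about how `insert_done_range` could keep `done_ranges` ordered, not the actual implementation:
import bisect
import datetime as dt

DateRange = tuple[dt.datetime, dt.datetime]


def insert_done_range_sketch(done_ranges: list[DateRange], done_range: DateRange) -> None:
    """Insert `done_range` keeping `done_ranges` sorted by range start (Python 3.10+)."""
    bisect.insort(done_ranges, done_range, key=lambda r: r[0])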
@pytest.mark.parametrize(
"initial_done_ranges,expected_done_ranges",
[
# Case 1: Disconnected ranges are not merged.
(
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)),
],
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)),
],
),
# Case 2: Connected ranges are merged.
(
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(10)),
],
[(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10))],
),
# Case 3: Connected ranges are merged, but disconnected are not.
(
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)),
(dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(10)),
(dt.datetime.fromtimestamp(11), dt.datetime.fromtimestamp(12)),
],
[
(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10)),
(dt.datetime.fromtimestamp(11), dt.datetime.fromtimestamp(12)),
],
),
],
)
def test_merge_done_ranges(initial_done_ranges, expected_done_ranges):
"""Test `BatchExportRangeHeartbeatDetails` merges done ranges."""
heartbeat_details = BatchExportRangeHeartbeatDetails()
heartbeat_details.done_ranges.extend(initial_done_ranges)
heartbeat_details.merge_done_ranges()
assert heartbeat_details.done_ranges == expected_done_ranges
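A minimal sketch of the merge behaviour the cases above describe: ranges that touch or overlap collapse into one, disconnected ranges are kept apart (illustrative only, not the library code):
import datetime as dt

DateRange = tuple[dt.datetime, dt.datetime]


def merge_done_ranges_sketch(done_ranges: list[DateRange]) -> list[DateRange]:
    """Merge touching or overlapping ranges; leave disconnected ranges alone."""
    merged: list[DateRange] = []
    for start, end in sorted(done_ranges):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged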

View File

@ -3,7 +3,7 @@ import json
import operator
import os
import warnings
from uuid import uuid4
import uuid
import psycopg
import pytest
@ -59,13 +59,13 @@ async def assert_clickhouse_records_in_redshfit(
table_name: str,
team_id: int,
batch_export_model: BatchExportModel | BatchExportSchema | None,
data_interval_start: dt.datetime,
data_interval_end: dt.datetime,
date_ranges: list[tuple[dt.datetime, dt.datetime]],
exclude_events: list[str] | None = None,
include_events: list[str] | None = None,
properties_data_type: str = "varchar",
sort_key: str = "event",
is_backfill: bool = False,
expected_duplicates_threshold: float = 0.0,
):
"""Assert expected records are written to a given Redshift table.
@ -89,6 +89,9 @@ async def assert_clickhouse_records_in_redshfit(
table_name: Redshift table name.
team_id: The ID of the team that we are testing events for.
batch_export_schema: Custom schema used in the batch export.
date_ranges: Ranges of records we should expect to have been exported.
        expected_duplicates_threshold: Maximum fraction of duplicate events, relative to the
            number of unique events, tolerated before failing the assertion.
"""
super_columns = ["properties", "set", "set_once", "person_properties"]
@ -132,33 +135,49 @@ async def assert_clickhouse_records_in_redshfit(
]
expected_records = []
async for record_batch in iter_model_records(
client=clickhouse_client,
model=batch_export_model,
team_id=team_id,
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
include_events=include_events,
destination_default_fields=redshift_default_fields(),
is_backfill=is_backfill,
):
for record in record_batch.select(schema_column_names).to_pylist():
expected_record = {}
for data_interval_start, data_interval_end in date_ranges:
async for record_batch in iter_model_records(
client=clickhouse_client,
model=batch_export_model,
team_id=team_id,
interval_start=data_interval_start.isoformat(),
interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
include_events=include_events,
destination_default_fields=redshift_default_fields(),
is_backfill=is_backfill,
):
for record in record_batch.select(schema_column_names).to_pylist():
expected_record = {}
for k, v in record.items():
if k not in schema_column_names or k == "_inserted_at":
# _inserted_at is not exported, only used for tracking progress.
continue
for k, v in record.items():
if k not in schema_column_names or k == "_inserted_at":
# _inserted_at is not exported, only used for tracking progress.
continue
elif k in super_columns and v is not None:
expected_record[k] = remove_escaped_whitespace_recursive(json.loads(v))
elif isinstance(v, dt.datetime):
expected_record[k] = v.replace(tzinfo=dt.UTC)
else:
expected_record[k] = v
elif k in super_columns and v is not None:
expected_record[k] = remove_escaped_whitespace_recursive(json.loads(v))
elif isinstance(v, dt.datetime):
expected_record[k] = v.replace(tzinfo=dt.UTC)
else:
expected_record[k] = v
expected_records.append(expected_record)
expected_records.append(expected_record)
if expected_duplicates_threshold > 0.0:
seen = set()
def is_record_seen(record) -> bool:
nonlocal seen
if record["uuid"] in seen:
return True
seen.add(record["uuid"])
return False
        total_count = len(inserted_records)
        inserted_records = [record for record in inserted_records if not is_record_seen(record)]
        # The ratio of duplicates to unique events must stay below the threshold.
        assert (total_count - len(inserted_records)) / len(inserted_records) < expected_duplicates_threshold
inserted_column_names = list(inserted_records[0].keys())
expected_column_names = list(expected_records[0].keys())
@ -171,6 +190,7 @@ async def assert_clickhouse_records_in_redshfit(
assert inserted_column_names == expected_column_names
assert inserted_records[0] == expected_records[0]
assert inserted_records == expected_records
assert len(inserted_records) == len(expected_records)
@pytest.fixture
@ -348,8 +368,7 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table(
schema_name=redshift_config["schema"],
table_name=table_name,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=model,
exclude_events=exclude_events,
properties_data_type=properties_data_type,
@ -357,6 +376,227 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table(
)
@pytest.mark.parametrize("interval", ["hour"], indirect=True)
@pytest.mark.parametrize(
"done_relative_ranges,expected_relative_ranges",
[
(
[(dt.timedelta(minutes=0), dt.timedelta(minutes=15))],
[(dt.timedelta(minutes=15), dt.timedelta(minutes=60))],
),
(
[
(dt.timedelta(minutes=10), dt.timedelta(minutes=15)),
(dt.timedelta(minutes=35), dt.timedelta(minutes=45)),
],
[
(dt.timedelta(minutes=0), dt.timedelta(minutes=10)),
(dt.timedelta(minutes=15), dt.timedelta(minutes=35)),
(dt.timedelta(minutes=45), dt.timedelta(minutes=60)),
],
),
(
[
(dt.timedelta(minutes=45), dt.timedelta(minutes=60)),
],
[
(dt.timedelta(minutes=0), dt.timedelta(minutes=45)),
],
),
],
)
async def test_insert_into_redshift_activity_resumes_from_heartbeat(
clickhouse_client,
activity_environment,
psycopg_connection,
redshift_config,
exclude_events,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
done_relative_ranges,
expected_relative_ranges,
):
"""Test we insert partial data into a BigQuery table when resuming.
After an activity runs, heartbeats, and crashes, a follow-up activity should
pick-up from where the first one left. This capability is critical to ensure
long-running activities that export a lot of data will eventually finish.
"""
batch_export_model = BatchExportModel(name="events", schema=None)
properties_data_type = "varchar"
insert_inputs = RedshiftInsertInputs(
team_id=ateam.pk,
table_name=f"test_insert_activity_table_{ateam.pk}",
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
batch_export_model=batch_export_model,
properties_data_type=properties_data_type,
**redshift_config,
)
done_ranges = [
(
(data_interval_start + done_relative_range[0]).isoformat(),
(data_interval_start + done_relative_range[1]).isoformat(),
)
for done_relative_range in done_relative_ranges
]
expected_ranges = [
(
(data_interval_start + expected_relative_range[0]),
(data_interval_start + expected_relative_range[1]),
)
for expected_relative_range in expected_relative_ranges
]
workflow_id = uuid.uuid4()
fake_info = activity.Info(
activity_id="insert-into-redshift-activity",
activity_type="unknown",
current_attempt_scheduled_time=dt.datetime.now(dt.UTC),
workflow_id=str(workflow_id),
workflow_type="redshift-export",
workflow_run_id=str(uuid.uuid4()),
attempt=1,
heartbeat_timeout=dt.timedelta(seconds=1),
heartbeat_details=[done_ranges],
is_local=False,
schedule_to_close_timeout=dt.timedelta(seconds=10),
scheduled_time=dt.datetime.now(dt.UTC),
start_to_close_timeout=dt.timedelta(seconds=20),
started_time=dt.datetime.now(dt.UTC),
task_queue="test",
task_token=b"test",
workflow_namespace="default",
)
activity_environment.info = fake_info
await activity_environment.run(insert_into_redshift_activity, insert_inputs)
await assert_clickhouse_records_in_redshfit(
redshift_connection=psycopg_connection,
clickhouse_client=clickhouse_client,
schema_name=redshift_config["schema"],
table_name=f"test_insert_activity_table_{ateam.pk}",
team_id=ateam.pk,
date_ranges=expected_ranges,
batch_export_model=batch_export_model,
exclude_events=exclude_events,
properties_data_type=properties_data_type,
sort_key="event",
expected_duplicates_threshold=0.1,
)
async def test_insert_into_redshift_activity_completes_range(
clickhouse_client,
activity_environment,
psycopg_connection,
redshift_config,
exclude_events,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
):
"""Test we complete a full range of data into a Redshift table when resuming.
We run two activities:
1. First activity, up to (and including) the cutoff event.
2. Second activity with a heartbeat detail matching the cutoff event.
This simulates the batch export resuming from a failed execution. The full range
should be completed (with a duplicate on the cutoff event) after both activities
are done.
"""
batch_export_model = BatchExportModel(name="events", schema=None)
properties_data_type = "varchar"
events_to_export_created, _ = generate_test_data
events_to_export_created.sort(key=operator.itemgetter("inserted_at"))
cutoff_event = events_to_export_created[len(events_to_export_created) // 2 : len(events_to_export_created) // 2 + 1]
assert len(cutoff_event) == 1
cutoff_event = cutoff_event[0]
cutoff_data_interval_end = dt.datetime.fromisoformat(cutoff_event["inserted_at"]).replace(tzinfo=dt.UTC)
insert_inputs = RedshiftInsertInputs(
team_id=ateam.pk,
table_name=f"test_insert_activity_table_{ateam.pk}",
data_interval_start=data_interval_start.isoformat(),
# The extra second is because the upper range select is exclusive and
# we want cutoff to be the last event included.
data_interval_end=(cutoff_data_interval_end + dt.timedelta(seconds=1)).isoformat(),
exclude_events=exclude_events,
batch_export_model=batch_export_model,
properties_data_type=properties_data_type,
**redshift_config,
)
await activity_environment.run(insert_into_redshift_activity, insert_inputs)
done_ranges = [
(
data_interval_start.isoformat(),
cutoff_data_interval_end.isoformat(),
),
]
workflow_id = uuid.uuid4()
fake_info = activity.Info(
activity_id="insert-into-bigquery-activity",
activity_type="unknown",
current_attempt_scheduled_time=dt.datetime.now(dt.UTC),
workflow_id=str(workflow_id),
workflow_type="bigquery-export",
workflow_run_id=str(uuid.uuid4()),
attempt=1,
heartbeat_timeout=dt.timedelta(seconds=1),
heartbeat_details=[done_ranges],
is_local=False,
schedule_to_close_timeout=dt.timedelta(seconds=10),
scheduled_time=dt.datetime.now(dt.UTC),
start_to_close_timeout=dt.timedelta(seconds=20),
started_time=dt.datetime.now(dt.UTC),
task_queue="test",
task_token=b"test",
workflow_namespace="default",
)
activity_environment.info = fake_info
insert_inputs = RedshiftInsertInputs(
team_id=ateam.pk,
table_name=f"test_insert_activity_table_{ateam.pk}",
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
batch_export_model=batch_export_model,
properties_data_type=properties_data_type,
**redshift_config,
)
await activity_environment.run(insert_into_redshift_activity, insert_inputs)
await assert_clickhouse_records_in_redshfit(
redshift_connection=psycopg_connection,
clickhouse_client=clickhouse_client,
schema_name=redshift_config["schema"],
table_name=f"test_insert_activity_table_{ateam.pk}",
team_id=ateam.pk,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=batch_export_model,
exclude_events=exclude_events,
properties_data_type=properties_data_type,
sort_key="event",
expected_duplicates_threshold=0.1,
)
@pytest.fixture
def table_name(ateam, interval):
return f"test_workflow_table_{ateam.pk}_{interval}"
@ -421,7 +661,7 @@ async def test_redshift_export_workflow(
elif model is not None:
batch_export_schema = model
workflow_id = str(uuid4())
workflow_id = str(uuid.uuid4())
inputs = RedshiftBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(redshift_batch_export.id),
@ -470,8 +710,7 @@ async def test_redshift_export_workflow(
schema_name=redshift_config["schema"],
table_name=table_name,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
date_ranges=[(data_interval_start, data_interval_end)],
batch_export_model=model,
exclude_events=exclude_events,
sort_key="person_id" if batch_export_model is not None and batch_export_model.name == "persons" else "event",
@ -495,11 +734,13 @@ def test_remove_escaped_whitespace_recursive(value, expected):
assert remove_escaped_whitespace_recursive(value) == expected
async def test_redshift_export_workflow_handles_insert_activity_errors(ateam, redshift_batch_export, interval):
async def test_redshift_export_workflow_handles_insert_activity_errors(
event_loop, ateam, redshift_batch_export, interval
):
"""Test that Redshift Export Workflow can gracefully handle errors when inserting Redshift data."""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
workflow_id = str(uuid4())
workflow_id = str(uuid.uuid4())
inputs = RedshiftBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(redshift_batch_export.id),
@ -531,6 +772,7 @@ async def test_redshift_export_workflow_handles_insert_activity_errors(ateam, re
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=20),
)
runs = await afetch_batch_export_runs(batch_export_id=redshift_batch_export.id)
@ -548,7 +790,7 @@ async def test_redshift_export_workflow_handles_insert_activity_non_retryable_er
"""Test that Redshift Export Workflow can gracefully handle non-retryable errors when inserting Redshift data."""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
workflow_id = str(uuid4())
workflow_id = str(uuid.uuid4())
inputs = RedshiftBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(redshift_batch_export.id),

View File

@ -28,7 +28,7 @@ from posthog.temporal.batch_exports.batch_exports import (
)
from posthog.temporal.batch_exports.s3_batch_export import (
FILE_FORMAT_EXTENSIONS,
HeartbeatDetails,
S3HeartbeatDetails,
IntermittentUploadPartTimeoutError,
S3BatchExportInputs,
S3BatchExportWorkflow,
@ -1010,6 +1010,12 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
)
class RetryableTestException(Exception):
"""An exception to be raised during tests"""
pass
async def test_s3_export_workflow_handles_insert_activity_errors(ateam, s3_batch_export, interval):
"""Test S3BatchExport Workflow can handle errors from executing the insert into S3 activity.
@ -1028,7 +1034,7 @@ async def test_s3_export_workflow_handles_insert_activity_errors(ateam, s3_batch
@activity.defn(name="insert_into_s3_activity")
async def insert_into_s3_activity_mocked(_: S3InsertInputs) -> str:
raise ValueError("A useful error message")
raise RetryableTestException("A useful error message")
async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
async with Worker(
@ -1056,7 +1062,7 @@ async def test_s3_export_workflow_handles_insert_activity_errors(ateam, s3_batch
run = runs[0]
assert run.status == "FailedRetryable"
assert run.latest_error == "ValueError: A useful error message"
assert run.latest_error == "RetryableTestException: A useful error message"
assert run.records_completed is None
@ -1387,14 +1393,14 @@ async def test_insert_into_s3_activity_heartbeats(
inserted_at=part_inserted_at,
)
heartbeat_details = []
heartbeat_details: list[S3HeartbeatDetails] = []
    def track_heartbeat_details(*details):
"""Record heartbeat details received."""
nonlocal heartbeat_details
details = HeartbeatDetails.from_activity_details(details)
heartbeat_details.append(details)
s3_details = S3HeartbeatDetails.from_activity_details(details)
heartbeat_details.append(s3_details)
    activity_environment.on_heartbeat = track_heartbeat_details
@ -1415,11 +1421,13 @@ async def test_insert_into_s3_activity_heartbeats(
assert len(heartbeat_details) > 0
for detail in heartbeat_details:
last_uploaded_part_dt = dt.datetime.fromisoformat(detail.last_uploaded_part_timestamp)
assert last_uploaded_part_dt == data_interval_end - s3_batch_export.interval_time_delta / len(
detail.upload_state.parts
)
detail = heartbeat_details[-1]
assert detail.upload_state is not None
assert len(detail.upload_state.parts) == 3
assert len(detail.done_ranges) == 1
assert detail.done_ranges[0] == (data_interval_start, data_interval_end)
await assert_clickhouse_records_in_s3(
s3_compatible_client=minio_client,

View File

@ -1631,7 +1631,13 @@ async def test_insert_into_snowflake_activity_heartbeats(
)
@pytest.mark.parametrize("details", [(str(dt.datetime.now()), 1)])
@pytest.mark.parametrize(
"details",
[
([(dt.datetime.now().isoformat(), dt.datetime.now().isoformat())], 1),
([(dt.datetime.now().isoformat(), dt.datetime.now().isoformat())],),
],
)
def test_snowflake_heartbeat_details_parses_from_tuple(details):
class FakeActivity:
def info(self):
@ -1642,6 +1648,16 @@ def test_snowflake_heartbeat_details_parses_from_tuple(details):
self.heartbeat_details = details
snowflake_details = SnowflakeHeartbeatDetails.from_activity(FakeActivity())
expected_done_ranges = details[0]
assert snowflake_details.last_inserted_at == dt.datetime.fromisoformat(details[0])
assert snowflake_details.file_no == details[1]
assert snowflake_details.done_ranges == [
(
dt.datetime.fromisoformat(expected_done_ranges[0][0]),
dt.datetime.fromisoformat(expected_done_ranges[0][1]),
)
]
if len(details) >= 2:
assert snowflake_details.file_no == details[1]
else:
assert snowflake_details.file_no == 0
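The asserts above describe the heartbeat payload shape: detail 0 is a list of ISO-formatted (start, end) pairs, detail 1 is an optional file number defaulting to 0. A hedged sketch of that parsing (the class and method names here are placeholders, not the real `SnowflakeHeartbeatDetails` API):
import dataclasses
import datetime as dt


@dataclasses.dataclass
class SnowflakeHeartbeatSketch:
    done_ranges: list[tuple[dt.datetime, dt.datetime]]
    file_no: int = 0

    @classmethod
    def from_details(cls, details: tuple) -> "SnowflakeHeartbeatSketch":
        """Parse done ranges (required) and file number (optional) from heartbeat details."""
        ranges = [
            (dt.datetime.fromisoformat(start), dt.datetime.fromisoformat(end))
            for start, end in details[0]
        ]
        file_no = details[1] if len(details) >= 2 else 0
        return cls(done_ranges=ranges, file_no=file_no)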

View File

@ -11,7 +11,7 @@ from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
CSVBatchExportWriter,
JSONLBatchExportWriter,
LastInsertedAt,
DateRange,
ParquetBatchExportWriter,
json_dumps_bytes,
)
@ -209,7 +209,9 @@ TEST_RECORD_BATCHES = [
{
"event": pa.array(["test-event-0", "test-event-1", "test-event-2"]),
"properties": pa.array(['{"prop_0": 1, "prop_1": 2}', "{}", "null"]),
"_inserted_at": pa.array([0, 1, 2]),
"_inserted_at": pa.array(
[dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(1), dt.datetime.fromtimestamp(2)]
),
}
)
]
@ -223,20 +225,20 @@ TEST_RECORD_BATCHES = [
async def test_jsonl_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid JSONL."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen: list[LastInsertedAt] = []
date_ranges_seen: list[DateRange] = []
async def store_in_memory_on_flush(
batch_export_file,
records_since_last_flush,
bytes_since_last_flush,
flush_counter,
last_inserted_at,
last_date_range,
is_last,
error,
):
assert writer.records_since_last_flush == record_batch.num_rows
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)
date_ranges_seen.append(last_date_range)
writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush)
@ -257,7 +259,9 @@ async def test_jsonl_writer_writes_record_batches(record_batch):
assert "_inserted_at" not in written_jsonl
assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"}
assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]
assert date_ranges_seen == [
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py())
]
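The flush callable used throughout these writer tests now receives a `last_date_range` pair, the first and last `_inserted_at` of the flushed batch, instead of a single `last_inserted_at` value. A sketch of a callback with that signature (the parameter list mirrors the tests; any storage logic is an assumption):
import datetime as dt

DateRange = tuple[dt.datetime, dt.datetime]


async def flush_and_track_progress(
    batch_export_file,
    records_since_last_flush: int,
    bytes_since_last_flush: int,
    flush_counter: int,
    last_date_range: DateRange,
    is_last: bool,
    error: Exception | None,
) -> None:
    """Ship the flushed bytes somewhere and record the covered range."""
    payload = batch_export_file.read()  # bytes of the records flushed so far
    start, end = last_date_range
    # e.g. upload `payload`, then heartbeat (start, end) as a done range.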
@pytest.mark.parametrize(
@ -268,19 +272,19 @@ async def test_jsonl_writer_writes_record_batches(record_batch):
async def test_csv_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid CSV."""
in_memory_file_obj = io.StringIO()
inserted_ats_seen = []
date_ranges_seen: list[DateRange] = []
async def store_in_memory_on_flush(
batch_export_file,
records_since_last_flush,
bytes_since_last_flush,
flush_counter,
last_inserted_at,
last_date_range,
is_last,
error,
):
in_memory_file_obj.write(batch_export_file.read().decode("utf-8"))
inserted_ats_seen.append(last_inserted_at)
date_ranges_seen.append(last_date_range)
schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]
writer = CSVBatchExportWriter(max_bytes=1, field_names=schema_columns, flush_callable=store_in_memory_on_flush)
@ -304,7 +308,9 @@ async def test_csv_writer_writes_record_batches(record_batch):
assert "_inserted_at" not in written_csv_row
assert written_csv_row == list({k: v for k, v in expected_dict.items() if k != "_inserted_at"}.values())
assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]
assert date_ranges_seen == [
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py())
]
@pytest.mark.parametrize(
@ -315,19 +321,19 @@ async def test_csv_writer_writes_record_batches(record_batch):
async def test_parquet_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid Parquet."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen = []
date_ranges_seen: list[DateRange] = []
async def store_in_memory_on_flush(
batch_export_file,
records_since_last_flush,
bytes_since_last_flush,
flush_counter,
last_inserted_at,
last_date_range,
is_last,
error,
):
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)
date_ranges_seen.append(last_date_range)
schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]
@ -353,9 +359,9 @@ async def test_parquet_writer_writes_record_batches(record_batch):
# NOTE: Parquet gets flushed twice due to the extra flush at the end for footer bytes, so our mock function
# will see this value twice.
assert inserted_ats_seen == [
record_batch.column("_inserted_at")[-1].as_py(),
record_batch.column("_inserted_at")[-1].as_py(),
assert date_ranges_seen == [
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()),
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()),
]
@ -412,7 +418,7 @@ async def test_flushing_parquet_writer_resets_underlying_file(record_batch):
assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset
assert writer.records_since_last_flush == record_batch.num_rows
await writer.flush(dt.datetime.now())
await writer.flush()
assert flush_counter == 1
assert writer.batch_export_file.tell() == 0
@ -427,7 +433,7 @@ async def test_flushing_parquet_writer_resets_underlying_file(record_batch):
async def test_jsonl_writer_deals_with_web_vitals():
"""Test old $web_vitals record batches are written as valid JSONL."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen: list[LastInsertedAt] = []
date_ranges_seen: list[DateRange] = []
record_batch = pa.RecordBatch.from_pydict(
{
@ -442,7 +448,7 @@ async def test_jsonl_writer_deals_with_web_vitals():
}
]
),
"_inserted_at": pa.array([0]),
"_inserted_at": pa.array([dt.datetime.fromtimestamp(0)]),
}
)
@ -451,13 +457,13 @@ async def test_jsonl_writer_deals_with_web_vitals():
records_since_last_flush,
bytes_since_last_flush,
flush_counter,
last_inserted_at,
last_date_range,
is_last,
error,
):
assert writer.records_since_last_flush == record_batch.num_rows
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)
date_ranges_seen.append(last_date_range)
writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush)
@ -479,20 +485,22 @@ async def test_jsonl_writer_deals_with_web_vitals():
del expected_jsonl["properties"]["$web_vitals_INP_event"]["attribution"]["interactionTargetElement"]
assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"}
assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]
assert date_ranges_seen == [
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py())
]
@pytest.mark.asyncio
async def test_jsonl_writer_deals_with_nested_user_events():
"""Test very nested user event record batches are written as valid JSONL."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen: list[LastInsertedAt] = []
date_ranges_seen: list[DateRange] = []
record_batch = pa.RecordBatch.from_pydict(
{
"event": pa.array(["my_event"]),
"properties": pa.array([{"we_have_to_go_deeper": json.loads("[" * 256 + "]" * 256)}]),
"_inserted_at": pa.array([0]),
"_inserted_at": pa.array([dt.datetime.fromtimestamp(0)]),
}
)
@ -501,13 +509,13 @@ async def test_jsonl_writer_deals_with_nested_user_events():
records_since_last_flush,
bytes_since_last_flush,
flush_counter,
last_inserted_at,
last_date_range,
is_last,
error,
):
assert writer.records_since_last_flush == record_batch.num_rows
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)
date_ranges_seen.append(last_date_range)
writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush)
@ -525,4 +533,6 @@ async def test_jsonl_writer_deals_with_nested_user_events():
assert "_inserted_at" not in written_jsonl
assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"}
assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]
assert date_ranges_seen == [
(record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py())
]

View File

@ -16,7 +16,6 @@ async def mocked_start_batch_export_run(inputs: StartBatchExportRunInputs) -> st
data_interval_start=inputs.data_interval_start,
data_interval_end=inputs.data_interval_end,
status=BatchExportRun.Status.STARTING,
records_total_count=1,
)
return str(run.id)

View File

@ -1,6 +1,7 @@
"""Test utilities that deal with test event generation."""
import datetime as dt
import itertools
import json
import random
import typing
@ -48,41 +49,52 @@ def generate_test_events(
distinct_ids: list[str] | None = None,
):
"""Generate a list of events for testing."""
_timestamp = random.choice(possible_datetimes)
datetime_sample = random.sample(possible_datetimes, len(possible_datetimes))
datetime_cycle = itertools.cycle(datetime_sample)
_timestamp = next(datetime_cycle)
if inserted_at == "_timestamp":
inserted_at_value = _timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")
elif inserted_at == "random":
inserted_at_value = random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f")
elif inserted_at is None:
inserted_at_value = None
if distinct_ids:
distinct_id_sample = random.sample(distinct_ids, len(distinct_ids))
distinct_id_cycle = itertools.cycle(distinct_id_sample)
else:
if not isinstance(inserted_at, dt.datetime):
raise ValueError(f"Unsupported value for inserted_at: '{inserted_at}'")
inserted_at_value = inserted_at.strftime("%Y-%m-%d %H:%M:%S.%f")
distinct_id_cycle = None
events: list[EventValues] = [
{
"_timestamp": _timestamp.strftime("%Y-%m-%d %H:%M:%S"),
"created_at": random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f"),
"distinct_id": random.choice(distinct_ids) if distinct_ids else str(uuid.uuid4()),
"elements": json.dumps("css selectors;"),
"elements_chain": "css selectors;",
"event": event_name.format(i=i),
"inserted_at": inserted_at_value,
"person_id": str(uuid.uuid4()),
"person_properties": person_properties,
"properties": properties,
"team_id": team_id,
"timestamp": random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f"),
"uuid": str(uuid.uuid4()),
"ip": ip,
"site_url": site_url,
"set": set_field,
"set_once": set_once,
}
for i in range(start, count + start)
]
def compute_inserted_at():
if inserted_at == "_timestamp":
inserted_at_value = _timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")
elif inserted_at == "random":
inserted_at_value = next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f")
elif inserted_at is None:
inserted_at_value = None
else:
if not isinstance(inserted_at, dt.datetime):
raise ValueError(f"Unsupported value for inserted_at: '{inserted_at}'")
inserted_at_value = inserted_at.strftime("%Y-%m-%d %H:%M:%S.%f")
return inserted_at_value
events: list[EventValues] = []
for i in range(start, count + start):
events.append(
{
"_timestamp": _timestamp.strftime("%Y-%m-%d %H:%M:%S"),
"created_at": next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f"),
"distinct_id": next(distinct_id_cycle) if distinct_id_cycle else str(uuid.uuid4()),
"elements": json.dumps("css selectors;"),
"elements_chain": "css selectors;",
"event": event_name.format(i=i),
"inserted_at": compute_inserted_at(),
"person_id": str(uuid.uuid4()),
"person_properties": person_properties,
"properties": properties,
"team_id": team_id,
"timestamp": next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f"),
"uuid": str(uuid.uuid4()),
"ip": ip,
"site_url": site_url,
"set": set_field,
"set_once": set_once,
}
)
return events
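The switch from `random.choice` to cycling over a full-length `random.sample` guarantees every datetime (and distinct id) in the pool appears before any value repeats, which keeps generated events spread across the interval. A tiny illustration of the pattern, using made-up values:
import itertools
import random

pool = ["a", "b", "c"]
# A full-length sample is a shuffle; cycling it repeats the shuffled order.
cycle = itertools.cycle(random.sample(pool, len(pool)))
values = [next(cycle) for _ in range(6)]
assert sorted(values) == ["a", "a", "b", "b", "c", "c"]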
@ -140,7 +152,7 @@ async def generate_test_events_in_clickhouse(
event_name: str = "test-{i}",
properties: dict | None = None,
person_properties: dict | None = None,
inserted_at: str | dt.datetime | None = "_timestamp",
inserted_at: str | dt.datetime | None = "random",
distinct_ids: list[str] | None = None,
duplicate: bool = False,
batch_size: int = 10000,