diff --git a/posthog/temporal/common/asyncpa.py b/posthog/temporal/common/asyncpa.py index fd9e1bd9e84..533f085c3d4 100644 --- a/posthog/temporal/common/asyncpa.py +++ b/posthog/temporal/common/asyncpa.py @@ -16,7 +16,7 @@ class InvalidMessageFormat(Exception): class AsyncMessageReader: """Asynchronously read PyArrow messages from bytes iterator.""" - def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]): + def __init__(self, bytes_iter: typing.AsyncIterator[bytes]): self._bytes = bytes_iter self._buffer = bytearray() @@ -64,7 +64,7 @@ class AsyncMessageReader: async def read_until(self, n: int) -> None: """Read from self._bytes until there are at least n bytes in self._buffer.""" while len(self._buffer) < n: - bytes, _ = await anext(self._bytes) + bytes = await anext(self._bytes) self._buffer.extend(bytes) def parse_body_size(self, metadata_flatbuffer: bytearray) -> int: @@ -105,7 +105,7 @@ class AsyncMessageReader: class AsyncRecordBatchReader: """Asynchronously read PyArrow RecordBatches from an iterator of bytes.""" - def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None: + def __init__(self, bytes_iter: typing.AsyncIterator[bytes]) -> None: self._reader = AsyncMessageReader(bytes_iter) self._schema: None | pa.Schema = None @@ -137,7 +137,7 @@ class AsyncRecordBatchReader: class AsyncRecordBatchProducer(AsyncRecordBatchReader): - def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None: + def __init__(self, bytes_iter: typing.AsyncIterator[bytes]) -> None: super().__init__(bytes_iter) async def produce(self, queue: asyncio.Queue): diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index c7f39a56471..bc618eb2dbf 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -6,6 +6,7 @@ import json import ssl import typing import uuid +import structlog import aiohttp import pyarrow as pa @@ -14,6 +15,8 @@ from django.conf import settings import posthog.temporal.common.asyncpa as asyncpa +logger = structlog.get_logger() + def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: """Encode data for ClickHouse. @@ -78,6 +81,29 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: return f"{quote_char}{str_data}{quote_char}".encode() +class ChunkBytesAsyncStreamIterator: + """Async iterator of HTTP chunk bytes. + + Similar to the class provided by aiohttp, but this allows us to control + when to stop iteration. + """ + + def __init__(self, stream: aiohttp.StreamReader) -> None: + self._stream = stream + + def __aiter__(self) -> "ChunkBytesAsyncStreamIterator": + return self + + async def __anext__(self) -> bytes: + data, end_of_chunk = await self._stream.readchunk() + + if data == b"" and end_of_chunk is False and self._stream.at_eof(): + await logger.adebug("At EOF, stopping chunk iteration") + raise StopAsyncIteration + + return data + + class ClickHouseClientNotConnected(Exception): """Exception raised when attempting to run an async query without connecting.""" @@ -386,7 +412,7 @@ class ClickHouseClient: This method makes sense when running with FORMAT ArrowStream, although we currently do not enforce this. """ async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: - reader = asyncpa.AsyncRecordBatchReader(response.content.iter_chunks()) + reader = asyncpa.AsyncRecordBatchReader(ChunkBytesAsyncStreamIterator(response.content)) async for batch in reader: yield batch @@ -405,7 +431,7 @@ class ClickHouseClient: downstream consumer tasks process them from the queue. """ async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: - reader = asyncpa.AsyncRecordBatchProducer(response.content.iter_chunks()) + reader = asyncpa.AsyncRecordBatchProducer(ChunkBytesAsyncStreamIterator(response.content)) await reader.produce(queue=queue) async def __aenter__(self):