Mirror of https://github.com/PostHog/posthog.git (synced 2024-11-21 13:39:22 +01:00)
feat: Implement chunk iteration for batch exports (#26292)
commit a0d32aa185 (parent 551290b688)
posthog/temporal/common/asyncpa.py

@@ -16,7 +16,7 @@ class InvalidMessageFormat(Exception):
 class AsyncMessageReader:
     """Asynchronously read PyArrow messages from bytes iterator."""
 
-    def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]):
+    def __init__(self, bytes_iter: typing.AsyncIterator[bytes]):
         self._bytes = bytes_iter
         self._buffer = bytearray()
 
@@ -64,7 +64,7 @@ class AsyncMessageReader:
     async def read_until(self, n: int) -> None:
         """Read from self._bytes until there are at least n bytes in self._buffer."""
         while len(self._buffer) < n:
-            bytes, _ = await anext(self._bytes)
+            bytes = await anext(self._bytes)
             self._buffer.extend(bytes)
 
     def parse_body_size(self, metadata_flatbuffer: bytearray) -> int:
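Note: `read_until` buffers whole chunks, so the buffer may end up holding more than `n` bytes; the surplus simply stays buffered for the next read. A minimal, self-contained illustration of that contract (example values only, not from the commit; requires Python 3.10+ for the `anext` builtin):

    import asyncio


    async def chunks():
        # Stand-in for the bytes iterator the reader consumes.
        for piece in (b"ab", b"cd", b"ef"):
            yield piece


    async def main() -> None:
        buffer = bytearray()
        it = chunks()
        # Equivalent of read_until(3): pull whole chunks until enough bytes.
        while len(buffer) < 3:
            buffer.extend(await anext(it))
        assert bytes(buffer) == b"abcd"  # overshoots n; extra bytes stay buffered


    asyncio.run(main())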
@@ -105,7 +105,7 @@ class AsyncMessageReader:
 class AsyncRecordBatchReader:
     """Asynchronously read PyArrow RecordBatches from an iterator of bytes."""
 
-    def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None:
+    def __init__(self, bytes_iter: typing.AsyncIterator[bytes]) -> None:
         self._reader = AsyncMessageReader(bytes_iter)
         self._schema: None | pa.Schema = None
 
@@ -137,7 +137,7 @@ class AsyncRecordBatchReader:
 
 
 class AsyncRecordBatchProducer(AsyncRecordBatchReader):
-    def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None:
+    def __init__(self, bytes_iter: typing.AsyncIterator[bytes]) -> None:
         super().__init__(bytes_iter)
 
     async def produce(self, queue: asyncio.Queue):
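Note: taken together, the asyncpa hunks relax the readers' input contract from aiohttp's `(bytes, bool)` chunk tuples to any plain `typing.AsyncIterator[bytes]`, decoupling the module from aiohttp. A sketch of what that enables, feeding `AsyncRecordBatchReader` from an in-memory Arrow IPC stream instead of an HTTP response (the generator below is illustrative and assumes the posthog package is importable; only the `asyncpa` import path comes from the diff):

    import asyncio
    import typing

    import pyarrow as pa

    import posthog.temporal.common.asyncpa as asyncpa


    async def arrow_stream_chunks(table: pa.Table) -> typing.AsyncIterator[bytes]:
        # Serialize a table to the Arrow IPC stream format, then yield it
        # in small pieces to mimic chunks arriving over the network.
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)
        data = sink.getvalue().to_pybytes()
        for i in range(0, len(data), 64):
            yield data[i : i + 64]


    async def main() -> None:
        table = pa.table({"event": ["pageview", "click"], "count": [1, 2]})
        reader = asyncpa.AsyncRecordBatchReader(arrow_stream_chunks(table))
        async for batch in reader:  # each item is a pa.RecordBatch
            print(batch.num_rows)


    asyncio.run(main())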
posthog/temporal/common/clickhouse.py

@@ -6,6 +6,7 @@ import json
 import ssl
 import typing
 import uuid
+import structlog
 
 import aiohttp
 import pyarrow as pa
@@ -14,6 +15,8 @@ from django.conf import settings
 
 import posthog.temporal.common.asyncpa as asyncpa
 
+logger = structlog.get_logger()
+
 
 def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes:
     """Encode data for ClickHouse.
@@ -78,6 +81,29 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes:
     return f"{quote_char}{str_data}{quote_char}".encode()
 
 
+class ChunkBytesAsyncStreamIterator:
+    """Async iterator of HTTP chunk bytes.
+
+    Similar to the class provided by aiohttp, but this allows us to control
+    when to stop iteration.
+    """
+
+    def __init__(self, stream: aiohttp.StreamReader) -> None:
+        self._stream = stream
+
+    def __aiter__(self) -> "ChunkBytesAsyncStreamIterator":
+        return self
+
+    async def __anext__(self) -> bytes:
+        data, end_of_chunk = await self._stream.readchunk()
+
+        if data == b"" and end_of_chunk is False and self._stream.at_eof():
+            await logger.adebug("At EOF, stopping chunk iteration")
+            raise StopAsyncIteration
+
+        return data
+
+
 class ClickHouseClientNotConnected(Exception):
     """Exception raised when attempting to run an async query without connecting."""
 
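Note: aiohttp's `StreamReader.readchunk()` returns a `(data, end_of_http_chunk)` tuple and yields `(b"", False)` once the body is exhausted, so the wrapper checks `at_eof()` to turn that terminal case into `StopAsyncIteration` rather than relying on aiohttp's own chunk iterator. A short usage sketch, assuming `ChunkBytesAsyncStreamIterator` is in scope (the URL and session setup are illustrative):

    import asyncio

    import aiohttp


    async def main() -> None:
        async with aiohttp.ClientSession() as session:
            async with session.get("https://example.com/stream") as response:
                # Each iteration yields raw bytes; iteration stops cleanly at EOF.
                async for data in ChunkBytesAsyncStreamIterator(response.content):
                    print(f"received {len(data)} bytes")


    asyncio.run(main())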
@@ -386,7 +412,7 @@ class ClickHouseClient:
         This method makes sense when running with FORMAT ArrowStream, although we currently do not enforce this.
         """
         async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response:
-            reader = asyncpa.AsyncRecordBatchReader(response.content.iter_chunks())
+            reader = asyncpa.AsyncRecordBatchReader(ChunkBytesAsyncStreamIterator(response.content))
             async for batch in reader:
                 yield batch
 
@@ -405,7 +431,7 @@
         downstream consumer tasks process them from the queue.
         """
         async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response:
-            reader = asyncpa.AsyncRecordBatchProducer(response.content.iter_chunks())
+            reader = asyncpa.AsyncRecordBatchProducer(ChunkBytesAsyncStreamIterator(response.content))
             await reader.produce(queue=queue)
 
     async def __aenter__(self):
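Note: both call sites now wrap `response.content` in `ChunkBytesAsyncStreamIterator` instead of calling `response.content.iter_chunks()`, which is what allowed dropping the tuple unpacking in `asyncpa`. A sketch of the queue-based produce/consume pattern the second call site serves (the stand-in producer and its None sentinel are assumptions for illustration, not taken from the commit):

    import asyncio

    import pyarrow as pa


    async def produce(queue: asyncio.Queue) -> None:
        # Stand-in for AsyncRecordBatchProducer.produce(): push batches, then
        # a None sentinel (the sentinel convention is assumed here).
        batch = pa.record_batch({"event": ["pageview"], "count": [1]})
        await queue.put(batch)
        await queue.put(None)


    async def consume(queue: asyncio.Queue) -> None:
        while True:
            batch = await queue.get()
            if batch is None:
                break
            print(batch.num_rows)  # process the pa.RecordBatch


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue(maxsize=10)
        await asyncio.gather(produce(queue), consume(queue))


    asyncio.run(main())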