mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-28 18:26:15 +01:00
110 lines
3.9 KiB
Python
import re
|
||
from contextlib import contextmanager
|
||
from functools import wraps
|
||
from typing import Any
|
||
from unittest.mock import patch
|
||
|
||
import pytest
|
||
import sqlparse
|
||
from django.db import DEFAULT_DB_ALIAS
|
||
|
||
from ee.clickhouse.client import ch_pool, sync_execute
|
||
from ee.clickhouse.sql.events import DROP_EVENTS_TABLE_SQL, EVENTS_TABLE_SQL
|
||
from ee.clickhouse.sql.person import DROP_PERSON_TABLE_SQL, PERSONS_TABLE_SQL
|
||
from posthog.test.base import BaseTest
|
||
|
||
|
||
@pytest.mark.usefixtures("unittest_snapshot")
class ClickhouseTestMixin:
    """Mixin for unittest-style tests that run queries against ClickHouse.

    Provides snapshot-based query assertions and a context manager for
    capturing SELECT queries issued through the ClickHouse client pool.
    """

    RUN_MATERIALIZED_COLUMN_TESTS = True
    # overrides the basetest in posthog/test/base.py
    # this way the team id will increment so we don't have to destroy all clickhouse tables on each test
    CLASS_DATA_LEVEL_SETUP = False

    # Injected by the `unittest_snapshot` pytest fixture (presumably a syrupy
    # snapshot object — it exposes `get_assert_diff()` below).
    snapshot: Any

    @contextmanager
    def _assertNumQueries(self, func):
        # No-op context manager: query counts are deliberately not checked
        # in ClickHouse tests (see assertNumQueries below).
        yield

    # Ignore assertNumQueries in clickhouse tests
    def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        return self._assertNumQueries(func)

    # :NOTE: Update snapshots by passing --snapshot-update to bin/tests
    def assertQueryMatchesSnapshot(self, query, params=None):
        """Assert that `query` (and optionally `params`) match the stored snapshot.

        The query is reformatted with sqlparse so whitespace differences do
        not cause flaky snapshot mismatches.
        """
        # :TRICKY: team_id changes every test, avoid it messing with snapshots.
        query = re.sub(r"(team|cohort)_id = \d+", r"\1_id = 2", query)

        assert sqlparse.format(query, reindent=True) == self.snapshot, "\n".join(self.snapshot.get_assert_diff())
        if params is not None:
            del params["team_id"]  # Changes every run
            assert params == self.snapshot, "\n".join(self.snapshot.get_assert_diff())

    @contextmanager
    def capture_select_queries(self):
        """Context manager yielding a list that accumulates every SELECT/WITH
        query executed through `ch_pool` while the context is active."""
        queries = []
        original_get_client = ch_pool.get_client

        # Spy on the `clichhouse_driver.Client.execute` method. This is a bit of
        # a roundabout way to handle this, but it seems tricky to spy on the
        # unbound class method `Client.execute` directly easily
        @contextmanager
        def get_client():
            with original_get_client() as client:
                original_client_execute = client.execute

                def execute_wrapper(query, *args, **kwargs):
                    # Strip comments first so a leading SQL comment doesn't
                    # hide a SELECT/WITH prefix.
                    if sqlparse.format(query, strip_comments=True).strip().startswith(("SELECT", "WITH")):
                        queries.append(query)
                    return original_client_execute(query, *args, **kwargs)

                with patch.object(client, "execute", wraps=execute_wrapper) as _:
                    yield client

        with patch("ee.clickhouse.client.ch_pool.get_client", wraps=get_client) as _:
            yield queries
|
||
|
||
|
||
class ClickhouseDestroyTablesMixin(BaseTest):
    """
    To speed up tests we normally don't destroy the tables between tests, so clickhouse tables will have data from previous tests.
    Use this mixin to make sure you completely destroy the tables between tests.
    """

    def _recreate_tables(self):
        # Drop and recreate the events/persons tables so each test starts
        # (and finishes) with completely empty ClickHouse tables.
        sync_execute(DROP_EVENTS_TABLE_SQL)
        sync_execute(EVENTS_TABLE_SQL)
        sync_execute(DROP_PERSON_TABLE_SQL)
        sync_execute(PERSONS_TABLE_SQL)

    def setUp(self):
        super().setUp()
        self._recreate_tables()

    def tearDown(self):
        super().tearDown()
        self._recreate_tables()
|
||
|
||
|
||
def snapshot_clickhouse_queries(fn):
    """
    Captures and snapshots select queries from test using `syrupy` library.

    Requires queries to be stable to avoid flakiness.

    Snapshots are automatically saved in a __snapshot__/*.ambr file.
    Update snapshots via --snapshot-update.
    """

    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        with self.capture_select_queries() as captured_queries:
            fn(self, *args, **kwargs)

        for captured_query in captured_queries:
            # Skip introspection queries against system.columns — they are
            # not part of the behavior under test.
            if "FROM system.columns" in captured_query:
                continue
            self.assertQueryMatchesSnapshot(captured_query)

    return wrapped
|