0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-28 18:26:15 +01:00
posthog/ee/clickhouse/util.py
2021-11-11 11:30:50 +00:00

110 lines
3.9 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from contextlib import contextmanager
from functools import wraps
from typing import Any
from unittest.mock import patch
import pytest
import sqlparse
from django.db import DEFAULT_DB_ALIAS
from ee.clickhouse.client import ch_pool, sync_execute
from ee.clickhouse.sql.events import DROP_EVENTS_TABLE_SQL, EVENTS_TABLE_SQL
from ee.clickhouse.sql.person import DROP_PERSON_TABLE_SQL, PERSONS_TABLE_SQL
from posthog.test.base import BaseTest
@pytest.mark.usefixtures("unittest_snapshot")
class ClickhouseTestMixin:
    """
    Test-case mixin with helpers for ClickHouse-backed tests.

    Replaces Django's ``assertNumQueries`` with a no-op and adds helpers to
    capture the read queries a test sends to ClickHouse and compare them
    against stored snapshots.
    """

    RUN_MATERIALIZED_COLUMN_TESTS = True
    # overrides the basetest in posthog/test/base.py
    #  this way the team id will increment so we don't have to destroy all clickhouse tables on each test
    CLASS_DATA_LEVEL_SETUP = False

    # NOTE(review): presumably populated by the `unittest_snapshot` fixture requested
    # above (a syrupy snapshot assertion object) — confirm against the fixture definition.
    snapshot: Any

    @contextmanager
    def _assertNumQueries(self, func):
        # No-op stand-in for Django's query-counting context: nothing is counted or asserted.
        yield

    # Ignore assertNumQueries in clickhouse tests
    def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
        """
        Django-compatible ``assertNumQueries`` that performs no query counting.

        Follows Django's calling convention. Without ``func``, it returns a
        context manager for use in a ``with`` statement. With ``func``, it calls
        ``func(*args, **kwargs)`` inside that context and returns None.

        ``num`` and ``using`` are accepted only for signature compatibility and
        are ignored.
        """
        context = self._assertNumQueries(func)
        if func is None:
            return context
        # Returning the context here (as before) would mean `func` is never called,
        # silently skipping the code under test.
        with context:
            func(*args, **kwargs)

    # :NOTE: Update snapshots by passing --snapshot-update to bin/tests
    def assertQueryMatchesSnapshot(self, query, params=None):
        """
        Assert that ``query`` (and ``params``, if given) match the next stored snapshot.

        ``team_id = <n>`` / ``cohort_id = <n>`` literals are rewritten to ``= 2``,
        and the query is reindented with sqlparse before comparing.

        ``params`` is compared without its ``team_id`` key. The caller's mapping
        is left unmodified, and a missing ``team_id`` key is tolerated.
        """
        # :TRICKY: team_id changes every test, avoid it messing with snapshots.
        query = re.sub(r"(team|cohort)_id = \d+", r"\1_id = 2", query)
        assert sqlparse.format(query, reindent=True) == self.snapshot, "\n".join(self.snapshot.get_assert_diff())
        if params is not None:
            # team_id changes every run. Filter it out of a copy instead of `del params["team_id"]`,
            # which mutated the caller's dict and raised KeyError when the key was absent.
            params = {key: value for key, value in params.items() if key != "team_id"}
            assert params == self.snapshot, "\n".join(self.snapshot.get_assert_diff())

    @contextmanager
    def capture_select_queries(self):
        """
        Record read queries executed through ``ch_pool.get_client`` while the context is active.

        Yields a list that grows as queries run. A query is recorded when, after
        sqlparse strips its comments and surrounding whitespace, it starts with
        ``SELECT`` or ``WITH``. This match is case-sensitive.

        Every query, recorded or not, is still passed on to the real client.
        """
        queries = []
        original_get_client = ch_pool.get_client

        # Spy on the `clickhouse_driver.Client.execute` method. This is a bit of
        # a roundabout way to handle this, but it seems tricky to spy on the
        # unbound class method `Client.execute` directly easily
        @contextmanager
        def get_client():
            with original_get_client() as client:
                original_client_execute = client.execute

                def execute_wrapper(query, *args, **kwargs):
                    if sqlparse.format(query, strip_comments=True).strip().startswith(("SELECT", "WITH")):
                        queries.append(query)
                    return original_client_execute(query, *args, **kwargs)

                with patch.object(client, "execute", wraps=execute_wrapper):
                    yield client

        with patch("ee.clickhouse.client.ch_pool.get_client", wraps=get_client):
            yield queries
class ClickhouseDestroyTablesMixin(BaseTest):
    """
    Drop and recreate the ClickHouse events and person tables around every test.

    For speed, tests normally share these tables, so rows written by earlier
    tests remain visible. Mix this in when a test must not see (or leave behind)
    that data. The tables are destroyed and recreated both after the base
    ``setUp`` and after the base ``tearDown``.
    """

    def _recreate_clickhouse_tables(self):
        # Executed in this exact order: each table is dropped before its definition is re-run.
        statements = (
            DROP_EVENTS_TABLE_SQL,
            EVENTS_TABLE_SQL,
            DROP_PERSON_TABLE_SQL,
            PERSONS_TABLE_SQL,
        )
        for statement in statements:
            sync_execute(statement)

    def setUp(self):
        super().setUp()
        self._recreate_clickhouse_tables()

    def tearDown(self):
        super().tearDown()
        self._recreate_clickhouse_tables()
def snapshot_clickhouse_queries(fn):
    """
    Decorate a test method so that each read query it runs is snapshot-checked.

    The decorated test runs inside ``self.capture_select_queries()``. Afterwards,
    every captured query except those containing ``FROM system.columns`` is passed,
    in capture order, to ``self.assertQueryMatchesSnapshot``.

    The test's own return value is discarded. Queries must be deterministic to
    avoid flaky snapshots. Snapshots are stored by the `syrupy` library in an
    ``.ambr`` file under the test module's ``__snapshots__`` directory.
    Regenerate them with ``--snapshot-update``.
    """

    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        with self.capture_select_queries() as captured:
            fn(self, *args, **kwargs)

        # Evaluated lazily so filtering and assertions interleave exactly as a plain loop would.
        to_check = (sql for sql in captured if "FROM system.columns" not in sql)
        for sql in to_check:
            self.assertQueryMatchesSnapshot(sql)

    return wrapped