mirror of
https://github.com/python/cpython.git
synced 2024-11-28 08:20:55 +01:00
78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
|
import contextlib
|
||
|
import functools
|
||
|
import io
|
||
|
import re
|
||
|
import sqlite3
|
||
|
import test.support
|
||
|
|
||
|
|
||
|
# Helper for temporary in-memory databases
def memory_database(*args, **kwargs):
    """Open an in-memory sqlite3 connection wrapped in contextlib.closing().

    Any extra positional and keyword arguments are forwarded to
    sqlite3.connect() after the ":memory:" database name.  Use the result in
    a ``with`` statement; the connection is closed when the block exits.
    """
    return contextlib.closing(sqlite3.connect(":memory:", *args, **kwargs))
|
||
|
|
||
|
|
||
|
# Temporarily limit a database connection parameter
|
||
|
@contextlib.contextmanager
|
||
|
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
|
||
|
try:
|
||
|
_prev = cx.setlimit(category, limit)
|
||
|
yield limit
|
||
|
finally:
|
||
|
cx.setlimit(category, _prev)
|
||
|
|
||
|
|
||
|
def with_tracebacks(exc, regex="", name=""):
    """Decorator factory for tests that exercise sqlite3 callback tracebacks.

    The decorated test method is run twice.  The first run happens under
    catch_unraisable_exception() and check_tracebacks(), which verify that
    an unraisable exception of type *exc* was recorded.  If given, *regex*
    must match its message, and *name* must equal the ``__name__`` of the
    recorded object.  The second run happens with callback tracebacks
    disabled again.
    """
    def decorator(func):
        pattern = None
        if regex:
            pattern = re.compile(regex)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with test.support.catch_unraisable_exception() as cm:
                # Pass 1: callback tracebacks enabled and checked.
                with check_tracebacks(self, cm, exc, pattern, name):
                    func(self, *args, **kwargs)

            # Pass 2: callback tracebacks disabled (check_tracebacks turns
            # them off when it exits).
            func(self, *args, **kwargs)

        return wrapper

    return decorator
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Convenience context manager for testing callback tracebacks.

    Enables sqlite3 callback tracebacks and sends stderr to a throwaway
    buffer while the with-block runs.  It then checks the unraisable
    exception recorded on *cm*:

    * an unraisable exception must have been recorded at all;
    * its type must equal *exc*;
    * if *regex* (a compiled pattern) is given, it must match somewhere in
      ``str()`` of the exception value;
    * if *obj_name* is given, it must equal ``__name__`` of the recorded
      ``object``.

    Callback tracebacks are disabled again on exit, even if an assertion
    fails.

    :param self: the test case whose assert methods are used.
    :param cm: the object exposing ``unraisable`` (as yielded by
        ``test.support.catch_unraisable_exception()``).
    """
    sqlite3.enable_callback_tracebacks(True)
    try:
        buf = io.StringIO()
        with contextlib.redirect_stderr(buf):
            yield

        unraisable = cm.unraisable
        # Report a clear test failure, rather than an AttributeError on None,
        # when the callback never produced an unraisable exception.
        self.assertIsNotNone(unraisable, "no unraisable exception was recorded")
        self.assertEqual(unraisable.exc_type, exc)
        if regex:
            # assertRegex accepts a compiled pattern and shows the text that
            # failed to match, unlike assertIsNotNone(regex.search(...)).
            self.assertRegex(str(unraisable.exc_value), regex)
        if obj_name:
            self.assertEqual(unraisable.object.__name__, obj_name)
    finally:
        sqlite3.enable_callback_tracebacks(False)
|
||
|
|
||
|
|
||
|
class MemoryDatabaseMixin:
    """Test-case mixin that gives each test a fresh in-memory connection.

    setUp() opens ``self.con`` on ":memory:" and a cursor ``self.cur``.
    tearDown() closes the cursor, then the connection.  The read-only
    properties ``cx`` and ``cu`` are aliases for those two attributes.
    """

    def setUp(self):
        connection = sqlite3.connect(":memory:")
        self.con = connection
        self.cur = connection.cursor()

    def tearDown(self):
        # Close the cursor before the connection that owns it.
        self.cur.close()
        self.con.close()

    @property
    def cx(self):
        """Alias for ``self.con``."""
        return self.con

    @property
    def cu(self):
        """Alias for ``self.cur``."""
        return self.cur
|