0
0
mirror of https://github.com/django/django.git synced 2024-11-25 07:59:34 +01:00
django/tests/gis_tests/utils.py
2022-02-07 20:37:05 +01:00

76 lines
2.4 KiB
Python

import copy
import unittest
from functools import wraps
from unittest import mock
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models import Func
def skipUnlessGISLookup(*gis_lookups):
"""
Skip a test unless a database supports all of gis_lookups.
"""
def decorator(test_func):
@wraps(test_func)
def skip_wrapper(*args, **kwargs):
if any(key not in connection.ops.gis_operators for key in gis_lookups):
raise unittest.SkipTest(
"Database doesn't support all the lookups: %s"
% ", ".join(gis_lookups)
)
return test_func(*args, **kwargs)
return skip_wrapper
return decorator
_default_db = settings.DATABASES[DEFAULT_DB_ALIAS]["ENGINE"].rsplit(".")[-1]
# MySQL spatial indices can't handle NULL geometries.
gisfield_may_be_null = _default_db != "mysql"
class FuncTestMixin:
"""Assert that Func expressions aren't mutated during their as_sql()."""
def setUp(self):
def as_sql_wrapper(original_as_sql):
def inner(*args, **kwargs):
func = original_as_sql.__self__
# Resolve output_field before as_sql() so touching it in
# as_sql() won't change __dict__.
func.output_field
__dict__original = copy.deepcopy(func.__dict__)
result = original_as_sql(*args, **kwargs)
msg = (
"%s Func was mutated during compilation." % func.__class__.__name__
)
self.assertEqual(func.__dict__, __dict__original, msg)
return result
return inner
def __getattribute__(self, name):
if name != vendor_impl:
return __getattribute__original(self, name)
try:
as_sql = __getattribute__original(self, vendor_impl)
except AttributeError:
as_sql = __getattribute__original(self, "as_sql")
return as_sql_wrapper(as_sql)
vendor_impl = "as_" + connection.vendor
__getattribute__original = Func.__getattribute__
self.func_patcher = mock.patch.object(
Func, "__getattribute__", __getattribute__
)
self.func_patcher.start()
super().setUp()
def tearDown(self):
super().tearDown()
self.func_patcher.stop()