# Mirror of https://github.com/python/cpython.git (synced 2024-11-21 21:09:37 +01:00).
# Commit bf542f8bb9: "Delay free a dictionary when replacing it".
# Original file: 212 lines, 5.7 KiB, Python.
import gc
|
|
import time
|
|
import unittest
|
|
import weakref
|
|
|
|
from ast import Or
|
|
from functools import partial
|
|
from threading import Thread
|
|
from unittest import TestCase
|
|
|
|
try:
|
|
import _testcapi
|
|
except ImportError:
|
|
_testcapi = None
|
|
|
|
from test.support import threading_helper
|
|
|
|
|
|
@threading_helper.requires_working_threading()
|
|
class TestDict(TestCase):
|
|
def test_racing_creation_shared_keys(self):
|
|
"""Verify that creating dictionaries is thread safe when we
|
|
have a type with shared keys"""
|
|
class C(int):
|
|
pass
|
|
|
|
self.racing_creation(C)
|
|
|
|
def test_racing_creation_no_shared_keys(self):
|
|
"""Verify that creating dictionaries is thread safe when we
|
|
have a type with an ordinary dict"""
|
|
self.racing_creation(Or)
|
|
|
|
def test_racing_creation_inline_values_invalid(self):
|
|
"""Verify that re-creating a dict after we have invalid inline values
|
|
is thread safe"""
|
|
class C:
|
|
pass
|
|
|
|
def make_obj():
|
|
a = C()
|
|
# Make object, make inline values invalid, and then delete dict
|
|
a.__dict__ = {}
|
|
del a.__dict__
|
|
return a
|
|
|
|
self.racing_creation(make_obj)
|
|
|
|
def test_racing_creation_nonmanaged_dict(self):
|
|
"""Verify that explicit creation of an unmanaged dict is thread safe
|
|
outside of the normal attribute setting code path"""
|
|
def make_obj():
|
|
def f(): pass
|
|
return f
|
|
|
|
def set(func, name, val):
|
|
# Force creation of the dict via PyObject_GenericGetDict
|
|
func.__dict__[name] = val
|
|
|
|
self.racing_creation(make_obj, set)
|
|
|
|
def racing_creation(self, cls, set=setattr):
|
|
objects = []
|
|
processed = []
|
|
|
|
OBJECT_COUNT = 100
|
|
THREAD_COUNT = 10
|
|
CUR = 0
|
|
|
|
for i in range(OBJECT_COUNT):
|
|
objects.append(cls())
|
|
|
|
def writer_func(name):
|
|
last = -1
|
|
while True:
|
|
if CUR == last:
|
|
continue
|
|
elif CUR == OBJECT_COUNT:
|
|
break
|
|
|
|
obj = objects[CUR]
|
|
set(obj, name, name)
|
|
last = CUR
|
|
processed.append(name)
|
|
|
|
writers = []
|
|
for x in range(THREAD_COUNT):
|
|
writer = Thread(target=partial(writer_func, f"a{x:02}"))
|
|
writers.append(writer)
|
|
writer.start()
|
|
|
|
for i in range(OBJECT_COUNT):
|
|
CUR = i
|
|
while len(processed) != THREAD_COUNT:
|
|
time.sleep(0.001)
|
|
processed.clear()
|
|
|
|
CUR = OBJECT_COUNT
|
|
|
|
for writer in writers:
|
|
writer.join()
|
|
|
|
for obj_idx, obj in enumerate(objects):
|
|
assert (
|
|
len(obj.__dict__) == THREAD_COUNT
|
|
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
|
|
for i in range(THREAD_COUNT):
|
|
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
|
|
|
|
def test_racing_set_dict(self):
|
|
"""Races assigning to __dict__ should be thread safe"""
|
|
|
|
def f(): pass
|
|
l = []
|
|
THREAD_COUNT = 10
|
|
class MyDict(dict): pass
|
|
|
|
def writer_func(l):
|
|
for i in range(1000):
|
|
d = MyDict()
|
|
l.append(weakref.ref(d))
|
|
f.__dict__ = d
|
|
|
|
lists = []
|
|
writers = []
|
|
for x in range(THREAD_COUNT):
|
|
thread_list = []
|
|
lists.append(thread_list)
|
|
writer = Thread(target=partial(writer_func, thread_list))
|
|
writers.append(writer)
|
|
|
|
for writer in writers:
|
|
writer.start()
|
|
|
|
for writer in writers:
|
|
writer.join()
|
|
|
|
f.__dict__ = {}
|
|
gc.collect()
|
|
|
|
for thread_list in lists:
|
|
for ref in thread_list:
|
|
self.assertIsNone(ref())
|
|
|
|
def test_racing_set_object_dict(self):
|
|
"""Races assigning to __dict__ should be thread safe"""
|
|
class C: pass
|
|
class MyDict(dict): pass
|
|
for cyclic in (False, True):
|
|
f = C()
|
|
f.__dict__ = {"foo": 42}
|
|
THREAD_COUNT = 10
|
|
|
|
def writer_func(l):
|
|
for i in range(1000):
|
|
if cyclic:
|
|
other_d = {}
|
|
d = MyDict({"foo": 100})
|
|
if cyclic:
|
|
d["x"] = other_d
|
|
other_d["bar"] = d
|
|
l.append(weakref.ref(d))
|
|
f.__dict__ = d
|
|
|
|
def reader_func():
|
|
for i in range(1000):
|
|
f.foo
|
|
|
|
lists = []
|
|
readers = []
|
|
writers = []
|
|
for x in range(THREAD_COUNT):
|
|
thread_list = []
|
|
lists.append(thread_list)
|
|
writer = Thread(target=partial(writer_func, thread_list))
|
|
writers.append(writer)
|
|
|
|
for x in range(THREAD_COUNT):
|
|
reader = Thread(target=partial(reader_func))
|
|
readers.append(reader)
|
|
|
|
for writer in writers:
|
|
writer.start()
|
|
for reader in readers:
|
|
reader.start()
|
|
|
|
for writer in writers:
|
|
writer.join()
|
|
|
|
for reader in readers:
|
|
reader.join()
|
|
|
|
f.__dict__ = {}
|
|
gc.collect()
|
|
gc.collect()
|
|
|
|
count = 0
|
|
ids = set()
|
|
for thread_list in lists:
|
|
for i, ref in enumerate(thread_list):
|
|
if ref() is None:
|
|
continue
|
|
count += 1
|
|
ids.add(id(ref()))
|
|
count += 1
|
|
|
|
self.assertEqual(count, 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|