diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/test/support/interpreters/__init__.py index d02ffbae111..d8e6654fc96 100644 --- a/Lib/test/support/interpreters/__init__.py +++ b/Lib/test/support/interpreters/__init__.py @@ -129,6 +129,14 @@ class Interpreter: def __del__(self): self._decref() + # for pickling: + def __getnewargs__(self): + return (self._id,) + + # for pickling: + def __getstate__(self): + return None + def _decref(self): if not self._ownsref: return diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/interpreters/channels.py index 75a5a60f54f..f7f523b1fc5 100644 --- a/Lib/test/support/interpreters/channels.py +++ b/Lib/test/support/interpreters/channels.py @@ -38,7 +38,8 @@ class _ChannelEnd: _end = None - def __init__(self, cid): + def __new__(cls, cid): + self = super().__new__(cls) if self._end == 'send': cid = _channels._channel_id(cid, send=True, force=True) elif self._end == 'recv': @@ -46,6 +47,7 @@ class _ChannelEnd: else: raise NotImplementedError(self._end) self._id = cid + return self def __repr__(self): return f'{type(self).__name__}(id={int(self._id)})' @@ -61,6 +63,14 @@ class _ChannelEnd: return NotImplemented return other._id == self._id + # for pickling: + def __getnewargs__(self): + return (int(self._id),) + + # for pickling: + def __getstate__(self): + return None + @property def id(self): return self._id diff --git a/Lib/test/support/interpreters/queues.py b/Lib/test/support/interpreters/queues.py index f9978f0bec5..5849a1cc15e 100644 --- a/Lib/test/support/interpreters/queues.py +++ b/Lib/test/support/interpreters/queues.py @@ -18,14 +18,14 @@ __all__ = [ ] -class QueueEmpty(_queues.QueueEmpty, queue.Empty): +class QueueEmpty(QueueError, queue.Empty): """Raised from get_nowait() when the queue is empty. It is also raised from get() if it times out. """ -class QueueFull(_queues.QueueFull, queue.Full): +class QueueFull(QueueError, queue.Full): """Raised from put_nowait() when the queue is full. It is also raised from put() if it times out. @@ -66,7 +66,7 @@ class Queue: else: raise TypeError(f'id must be an int, got {id!r}') if _fmt is None: - _fmt = _queues.get_default_fmt(id) + _fmt, = _queues.get_queue_defaults(id) try: self = _known_queues[id] except KeyError: @@ -93,6 +93,14 @@ class Queue: def __hash__(self): return hash(self._id) + # for pickling: + def __getnewargs__(self): + return (self._id,) + + # for pickling: + def __getstate__(self): + return None + @property def id(self): return self._id @@ -159,9 +167,8 @@ class Queue: while True: try: _queues.put(self._id, obj, fmt) - except _queues.QueueFull as exc: + except QueueFull as exc: if timeout is not None and time.time() >= end: - exc.__class__ = QueueFull raise # re-raise time.sleep(_delay) else: @@ -174,11 +181,7 @@ class Queue: fmt = _SHARED_ONLY if syncobj else _PICKLED if fmt is _PICKLED: obj = pickle.dumps(obj) - try: - _queues.put(self._id, obj, fmt) - except _queues.QueueFull as exc: - exc.__class__ = QueueFull - raise # re-raise + _queues.put(self._id, obj, fmt) def get(self, timeout=None, *, _delay=10 / 1000, # 10 milliseconds @@ -195,9 +198,8 @@ class Queue: while True: try: obj, fmt = _queues.get(self._id) - except _queues.QueueEmpty as exc: + except QueueEmpty as exc: if timeout is not None and time.time() >= end: - exc.__class__ = QueueEmpty raise # re-raise time.sleep(_delay) else: @@ -216,8 +218,7 @@ class Queue: """ try: obj, fmt = _queues.get(self._id) - except _queues.QueueEmpty as exc: - exc.__class__ = QueueEmpty + except QueueEmpty as exc: raise # re-raise if fmt == _PICKLED: obj = pickle.loads(obj) @@ -226,4 +227,4 @@ class Queue: return obj -_queues._register_queue_type(Queue) +_queues._register_heap_types(Queue, QueueEmpty, QueueFull) diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py index 363143fa810..3cde9bd0014 100644 --- a/Lib/test/test_interpreters/test_api.py +++ b/Lib/test/test_interpreters/test_api.py @@ -1,4 +1,5 @@ import os +import pickle import threading from textwrap import dedent import unittest @@ -261,6 +262,12 @@ class InterpreterObjectTests(TestBase): self.assertEqual(interp1, interp1) self.assertNotEqual(interp1, interp2) + def test_pickle(self): + interp = interpreters.create() + data = pickle.dumps(interp) + unpickled = pickle.loads(data) + self.assertEqual(unpickled, interp) + class TestInterpreterIsRunning(TestBase): diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py index 57204e27764..7e0b82884c3 100644 --- a/Lib/test/test_interpreters/test_channels.py +++ b/Lib/test/test_interpreters/test_channels.py @@ -1,4 +1,5 @@ import importlib +import pickle import threading from textwrap import dedent import unittest @@ -100,6 +101,12 @@ class TestRecvChannelAttrs(TestBase): self.assertEqual(ch1, ch1) self.assertNotEqual(ch1, ch2) + def test_pickle(self): + ch, _ = channels.create() + data = pickle.dumps(ch) + unpickled = pickle.loads(data) + self.assertEqual(unpickled, ch) + class TestSendChannelAttrs(TestBase): @@ -125,6 +132,12 @@ class TestSendChannelAttrs(TestBase): self.assertEqual(ch1, ch1) self.assertNotEqual(ch1, ch2) + def test_pickle(self): + _, ch = channels.create() + data = pickle.dumps(ch) + unpickled = pickle.loads(data) + self.assertEqual(unpickled, ch) + class TestSendRecv(TestBase): diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py index 0a1fdb41f73..d16d294b82d 100644 --- a/Lib/test/test_interpreters/test_queues.py +++ b/Lib/test/test_interpreters/test_queues.py @@ -1,20 +1,25 @@ import importlib +import pickle import threading from textwrap import dedent import unittest import time -from test.support import import_helper +from test.support import import_helper, Py_DEBUG # Raise SkipTest if subinterpreters not supported. _queues = import_helper.import_module('_xxinterpqueues') from test.support import interpreters from test.support.interpreters import queues -from .utils import _run_output, TestBase +from .utils import _run_output, TestBase as _TestBase -class TestBase(TestBase): +def get_num_queues(): + return len(_queues.list_all()) + + +class TestBase(_TestBase): def tearDown(self): - for qid in _queues.list_all(): + for qid, _ in _queues.list_all(): try: _queues.destroy(qid) except Exception: @@ -34,6 +39,58 @@ class LowLevelTests(TestBase): # See gh-115490 (https://github.com/python/cpython/issues/115490). importlib.reload(queues) + def test_create_destroy(self): + qid = _queues.create(2, 0) + _queues.destroy(qid) + self.assertEqual(get_num_queues(), 0) + with self.assertRaises(queues.QueueNotFoundError): + _queues.get(qid) + with self.assertRaises(queues.QueueNotFoundError): + _queues.destroy(qid) + + def test_not_destroyed(self): + # It should have cleaned up any remaining queues. + stdout, stderr = self.assert_python_ok( + '-c', + dedent(f""" + import {_queues.__name__} as _queues + _queues.create(2, 0) + """), + ) + self.assertEqual(stdout, '') + if Py_DEBUG: + self.assertNotEqual(stderr, '') + else: + self.assertEqual(stderr, '') + + def test_bind_release(self): + with self.subTest('typical'): + qid = _queues.create(2, 0) + _queues.bind(qid) + _queues.release(qid) + self.assertEqual(get_num_queues(), 0) + + with self.subTest('bind too much'): + qid = _queues.create(2, 0) + _queues.bind(qid) + _queues.bind(qid) + _queues.release(qid) + _queues.destroy(qid) + self.assertEqual(get_num_queues(), 0) + + with self.subTest('nested'): + qid = _queues.create(2, 0) + _queues.bind(qid) + _queues.bind(qid) + _queues.release(qid) + _queues.release(qid) + self.assertEqual(get_num_queues(), 0) + + with self.subTest('release without binding'): + qid = _queues.create(2, 0) + with self.assertRaises(queues.QueueError): + _queues.release(qid) + class QueueTests(TestBase): @@ -127,6 +184,12 @@ class QueueTests(TestBase): self.assertEqual(queue1, queue1) self.assertNotEqual(queue1, queue2) + def test_pickle(self): + queue = queues.create() + data = pickle.dumps(queue) + unpickled = pickle.loads(data) + self.assertEqual(unpickled, queue) + class TestQueueOps(TestBase): diff --git a/Modules/_interpreters_common.h b/Modules/_interpreters_common.h index 5661a26d879..07120f6ccc7 100644 --- a/Modules/_interpreters_common.h +++ b/Modules/_interpreters_common.h @@ -11,3 +11,11 @@ ensure_xid_class(PyTypeObject *cls, crossinterpdatafunc getdata) //assert(cls->tp_flags & Py_TPFLAGS_HEAPTYPE); return _PyCrossInterpreterData_RegisterClass(cls, getdata); } + +#ifdef REGISTERS_HEAP_TYPES +static int +clear_xid_class(PyTypeObject *cls) +{ + return _PyCrossInterpreterData_UnregisterClass(cls); +} +#endif diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 0ad184a78e8..28ec00a159d 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -17,7 +17,9 @@ #include // sched_yield() #endif +#define REGISTERS_HEAP_TYPES #include "_interpreters_common.h" +#undef REGISTERS_HEAP_TYPES /* @@ -281,17 +283,17 @@ clear_xid_types(module_state *state) { /* external types */ if (state->send_channel_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type); + (void)clear_xid_class(state->send_channel_type); Py_CLEAR(state->send_channel_type); } if (state->recv_channel_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type); + (void)clear_xid_class(state->recv_channel_type); Py_CLEAR(state->recv_channel_type); } /* heap types */ if (state->ChannelIDType != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); + (void)clear_xid_class(state->ChannelIDType); Py_CLEAR(state->ChannelIDType); } } @@ -2677,11 +2679,11 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) // Clear the old values if the .py module was reloaded. if (state->send_channel_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type); + (void)clear_xid_class(state->send_channel_type); Py_CLEAR(state->send_channel_type); } if (state->recv_channel_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type); + (void)clear_xid_class(state->recv_channel_type); Py_CLEAR(state->recv_channel_type); } @@ -2694,7 +2696,7 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) return -1; } if (ensure_xid_class(recv, _channelend_shared) < 0) { - (void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type); + (void)clear_xid_class(state->send_channel_type); Py_CLEAR(state->send_channel_type); Py_CLEAR(state->recv_channel_type); return -1; diff --git a/Modules/_xxinterpqueuesmodule.c b/Modules/_xxinterpqueuesmodule.c index 1b76b6963ae..cb8b9e4a661 100644 --- a/Modules/_xxinterpqueuesmodule.c +++ b/Modules/_xxinterpqueuesmodule.c @@ -8,7 +8,9 @@ #include "Python.h" #include "pycore_crossinterp.h" // struct _xid +#define REGISTERS_HEAP_TYPES #include "_interpreters_common.h" +#undef REGISTERS_HEAP_TYPES #define MODULE_NAME _xxinterpqueues @@ -128,6 +130,22 @@ idarg_int64_converter(PyObject *arg, void *ptr) } +static int +ensure_highlevel_module_loaded(void) +{ + PyObject *highlevel = PyImport_ImportModule("interpreters.queues"); + if (highlevel == NULL) { + PyErr_Clear(); + highlevel = PyImport_ImportModule("test.support.interpreters.queues"); + if (highlevel == NULL) { + return -1; + } + } + Py_DECREF(highlevel); + return 0; +} + + /* module state *************************************************************/ typedef struct { @@ -170,7 +188,7 @@ clear_module_state(module_state *state) { /* external types */ if (state->queue_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->queue_type); + (void)clear_xid_class(state->queue_type); } Py_CLEAR(state->queue_type); @@ -195,6 +213,9 @@ clear_module_state(module_state *state) // single-queue errors #define ERR_QUEUE_EMPTY (-21) #define ERR_QUEUE_FULL (-22) +#define ERR_QUEUE_NEVER_BOUND (-23) + +static int ensure_external_exc_types(module_state *); static int resolve_module_errcode(module_state *state, int errcode, int64_t qid, @@ -212,13 +233,23 @@ resolve_module_errcode(module_state *state, int errcode, int64_t qid, msg = PyUnicode_FromFormat("queue %" PRId64 " not found", qid); break; case ERR_QUEUE_EMPTY: + if (ensure_external_exc_types(state) < 0) { + return -1; + } exctype = state->QueueEmpty; msg = PyUnicode_FromFormat("queue %" PRId64 " is empty", qid); break; case ERR_QUEUE_FULL: + if (ensure_external_exc_types(state) < 0) { + return -1; + } exctype = state->QueueFull; msg = PyUnicode_FromFormat("queue %" PRId64 " is full", qid); break; + case ERR_QUEUE_NEVER_BOUND: + exctype = state->QueueError; + msg = PyUnicode_FromFormat("queue %" PRId64 " never bound", qid); + break; default: PyErr_Format(PyExc_ValueError, "unsupported error code %d", errcode); @@ -267,20 +298,59 @@ add_QueueError(PyObject *mod) #define PREFIX "test.support.interpreters." #define ADD_EXCTYPE(NAME, BASE, DOC) \ + assert(state->NAME == NULL); \ if (add_exctype(mod, &state->NAME, PREFIX #NAME, DOC, BASE) < 0) { \ return -1; \ } ADD_EXCTYPE(QueueError, PyExc_RuntimeError, "Indicates that a queue-related error happened.") ADD_EXCTYPE(QueueNotFoundError, state->QueueError, NULL) - ADD_EXCTYPE(QueueEmpty, state->QueueError, NULL) - ADD_EXCTYPE(QueueFull, state->QueueError, NULL) + // QueueEmpty and QueueFull are set by set_external_exc_types(). + state->QueueEmpty = NULL; + state->QueueFull = NULL; #undef ADD_EXCTYPE #undef PREFIX return 0; } +static int +set_external_exc_types(module_state *state, + PyObject *emptyerror, PyObject *fullerror) +{ + if (state->QueueEmpty != NULL) { + assert(state->QueueFull != NULL); + Py_CLEAR(state->QueueEmpty); + Py_CLEAR(state->QueueFull); + } + else { + assert(state->QueueFull == NULL); + } + assert(PyObject_IsSubclass(emptyerror, state->QueueError)); + assert(PyObject_IsSubclass(fullerror, state->QueueError)); + state->QueueEmpty = Py_NewRef(emptyerror); + state->QueueFull = Py_NewRef(fullerror); + return 0; +} + +static int +ensure_external_exc_types(module_state *state) +{ + if (state->QueueEmpty != NULL) { + assert(state->QueueFull != NULL); + return 0; + } + assert(state->QueueFull == NULL); + + // Force the module to be loaded, to register the type. + if (ensure_highlevel_module_loaded() < 0) { + return -1; + } + assert(state->QueueEmpty != NULL); + assert(state->QueueFull != NULL); + return 0; +} + static int handle_queue_error(int err, PyObject *mod, int64_t qid) { @@ -393,6 +463,7 @@ _queueitem_popped(_queueitem *item, /* the queue */ + typedef struct _queue { Py_ssize_t num_waiters; // protected by global lock PyThread_type_lock mutex; @@ -435,6 +506,8 @@ _queue_clear(_queue *queue) *queue = (_queue){0}; } +static void _queue_free(_queue *); + static void _queue_kill_and_wait(_queue *queue) { @@ -667,6 +740,32 @@ _queuerefs_find(_queueref *first, int64_t qid, _queueref **pprev) return ref; } +static void +_queuerefs_clear(_queueref *head) +{ + _queueref *next = head; + while (next != NULL) { + _queueref *ref = next; + next = ref->next; + +#ifdef Py_DEBUG + int64_t qid = ref->qid; + fprintf(stderr, "queue %ld still exists\n", qid); +#endif + _queue *queue = ref->queue; + GLOBAL_FREE(ref); + + _queue_kill_and_wait(queue); +#ifdef Py_DEBUG + if (queue->items.count > 0) { + fprintf(stderr, "queue %ld still holds %ld items\n", + qid, queue->items.count); + } +#endif + _queue_free(queue); + } +} + /* a collection of queues ***************************************************/ @@ -689,8 +788,15 @@ _queues_init(_queues *queues, PyThread_type_lock mutex) static void _queues_fini(_queues *queues) { - assert(queues->count == 0); - assert(queues->head == NULL); + if (queues->count > 0) { + PyThread_acquire_lock(queues->mutex, WAIT_LOCK); + assert((queues->count == 0) != (queues->head != NULL)); + _queueref *head = queues->head; + queues->head = NULL; + queues->count = 0; + PyThread_release_lock(queues->mutex); + _queuerefs_clear(head); + } if (queues->mutex != NULL) { PyThread_free_lock(queues->mutex); queues->mutex = NULL; @@ -822,19 +928,21 @@ done: return res; } -static void _queue_free(_queue *); - -static void +static int _queues_decref(_queues *queues, int64_t qid) { + int res = -1; PyThread_acquire_lock(queues->mutex, WAIT_LOCK); _queueref *prev = NULL; _queueref *ref = _queuerefs_find(queues->head, qid, &prev); if (ref == NULL) { assert(!PyErr_Occurred()); - // Already destroyed. - // XXX Warn? + res = ERR_QUEUE_NOT_FOUND; + goto finally; + } + if (ref->refcount == 0) { + res = ERR_QUEUE_NEVER_BOUND; goto finally; } assert(ref->refcount > 0); @@ -849,11 +957,13 @@ _queues_decref(_queues *queues, int64_t qid) _queue_kill_and_wait(queue); _queue_free(queue); - return; + return 0; } + res = 0; finally: PyThread_release_lock(queues->mutex); + return res; } struct queue_id_and_fmt { @@ -1077,14 +1187,11 @@ static int _queueobj_shared(PyThreadState *, PyObject *, _PyCrossInterpreterData *); static int -set_external_queue_type(PyObject *module, PyTypeObject *queue_type) +set_external_queue_type(module_state *state, PyTypeObject *queue_type) { - module_state *state = get_module_state(module); - // Clear the old value if the .py module was reloaded. if (state->queue_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass( - state->queue_type); + (void)clear_xid_class(state->queue_type); Py_CLEAR(state->queue_type); } @@ -1105,15 +1212,9 @@ get_external_queue_type(PyObject *module) PyTypeObject *cls = state->queue_type; if (cls == NULL) { // Force the module to be loaded, to register the type. - PyObject *highlevel = PyImport_ImportModule("interpreters.queue"); - if (highlevel == NULL) { - PyErr_Clear(); - highlevel = PyImport_ImportModule("test.support.interpreters.queue"); - if (highlevel == NULL) { - return NULL; - } + if (ensure_highlevel_module_loaded() < 0) { + return NULL; } - Py_DECREF(highlevel); cls = state->queue_type; assert(cls != NULL); } @@ -1152,7 +1253,14 @@ _queueid_xid_free(void *data) int64_t qid = ((struct _queueid_xid *)data)->qid; PyMem_RawFree(data); _queues *queues = _get_global_queues(); - _queues_decref(queues, qid); + int res = _queues_decref(queues, qid); + if (res == ERR_QUEUE_NOT_FOUND) { + // Already destroyed. + // XXX Warn? + } + else { + assert(res == 0); + } } static PyObject * @@ -1319,10 +1427,13 @@ queuesmod_create(PyObject *self, PyObject *args, PyObject *kwds) } PyDoc_STRVAR(queuesmod_create_doc, -"create() -> qid\n\ +"create(maxsize, fmt) -> qid\n\ \n\ Create a new cross-interpreter queue and return its unique generated ID.\n\ -It is a new reference as though bind() had been called on the queue."); +It is a new reference as though bind() had been called on the queue.\n\ +\n\ +The caller is responsible for calling destroy() for the new queue\n\ +before the runtime is finalized."); static PyObject * queuesmod_destroy(PyObject *self, PyObject *args, PyObject *kwds) @@ -1379,9 +1490,10 @@ finally: } PyDoc_STRVAR(queuesmod_list_all_doc, -"list_all() -> [qid]\n\ +"list_all() -> [(qid, fmt)]\n\ \n\ -Return the list of IDs for all queues."); +Return the list of IDs for all queues.\n\ +Each corresponding default format is also included."); static PyObject * queuesmod_put(PyObject *self, PyObject *args, PyObject *kwds) @@ -1398,6 +1510,7 @@ queuesmod_put(PyObject *self, PyObject *args, PyObject *kwds) /* Queue up the object. */ int err = queue_put(&_globals.queues, qid, obj, fmt); + // This is the only place that raises QueueFull. if (handle_queue_error(err, self, qid)) { return NULL; } @@ -1406,18 +1519,17 @@ queuesmod_put(PyObject *self, PyObject *args, PyObject *kwds) } PyDoc_STRVAR(queuesmod_put_doc, -"put(qid, obj, sharedonly=False)\n\ +"put(qid, obj, fmt)\n\ \n\ Add the object's data to the queue."); static PyObject * queuesmod_get(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"qid", "default", NULL}; + static char *kwlist[] = {"qid", NULL}; qidarg_converter_data qidarg; - PyObject *dflt = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O:get", kwlist, - qidarg_converter, &qidarg, &dflt)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:get", kwlist, + qidarg_converter, &qidarg)) { return NULL; } int64_t qid = qidarg.id; @@ -1425,11 +1537,8 @@ queuesmod_get(PyObject *self, PyObject *args, PyObject *kwds) PyObject *obj = NULL; int fmt = 0; int err = queue_get(&_globals.queues, qid, &obj, &fmt); - if (err == ERR_QUEUE_EMPTY && dflt != NULL) { - assert(obj == NULL); - obj = Py_NewRef(dflt); - } - else if (handle_queue_error(err, self, qid)) { + // This is the only place that raises QueueEmpty. + if (handle_queue_error(err, self, qid)) { return NULL; } @@ -1439,12 +1548,12 @@ queuesmod_get(PyObject *self, PyObject *args, PyObject *kwds) } PyDoc_STRVAR(queuesmod_get_doc, -"get(qid, [default]) -> obj\n\ +"get(qid) -> (obj, fmt)\n\ \n\ Return a new object from the data at the front of the queue.\n\ +The object's format is also returned.\n\ \n\ -If there is nothing to receive then raise QueueEmpty, unless\n\ -a default value is provided. In that case return it."); +If there is nothing to receive then raise QueueEmpty."); static PyObject * queuesmod_bind(PyObject *self, PyObject *args, PyObject *kwds) @@ -1491,7 +1600,10 @@ queuesmod_release(PyObject *self, PyObject *args, PyObject *kwds) // XXX Check module state if bound already. // XXX Update module state. - _queues_decref(&_globals.queues, qid); + int err = _queues_decref(&_globals.queues, qid); + if (handle_queue_error(err, self, qid)) { + return NULL; + } Py_RETURN_NONE; } @@ -1528,12 +1640,12 @@ PyDoc_STRVAR(queuesmod_get_maxsize_doc, Return the maximum number of items in the queue."); static PyObject * -queuesmod_get_default_fmt(PyObject *self, PyObject *args, PyObject *kwds) +queuesmod_get_queue_defaults(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"qid", NULL}; qidarg_converter_data qidarg; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O&:get_default_fmt", kwlist, + "O&:get_queue_defaults", kwlist, qidarg_converter, &qidarg)) { return NULL; } @@ -1546,13 +1658,21 @@ queuesmod_get_default_fmt(PyObject *self, PyObject *args, PyObject *kwds) } int fmt = queue->fmt; _queue_unmark_waiter(queue, _globals.queues.mutex); - return PyLong_FromLong(fmt); + + PyObject *fmt_obj = PyLong_FromLong(fmt); + if (fmt_obj == NULL) { + return NULL; + } + // For now queues only have one default. + PyObject *res = PyTuple_Pack(1, fmt_obj); + Py_DECREF(fmt_obj); + return res; } -PyDoc_STRVAR(queuesmod_get_default_fmt_doc, -"get_default_fmt(qid)\n\ +PyDoc_STRVAR(queuesmod_get_queue_defaults_doc, +"get_queue_defaults(qid)\n\ \n\ -Return the default format to use for the queue."); +Return the queue's default values, set when it was created."); static PyObject * queuesmod_is_full(PyObject *self, PyObject *args, PyObject *kwds) @@ -1609,22 +1729,39 @@ PyDoc_STRVAR(queuesmod_get_count_doc, Return the number of items in the queue."); static PyObject * -queuesmod__register_queue_type(PyObject *self, PyObject *args, PyObject *kwds) +queuesmod__register_heap_types(PyObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"queuetype", NULL}; + static char *kwlist[] = {"queuetype", "emptyerror", "fullerror", NULL}; PyObject *queuetype; + PyObject *emptyerror; + PyObject *fullerror; if (!PyArg_ParseTupleAndKeywords(args, kwds, - "O:_register_queue_type", kwlist, - &queuetype)) { + "OOO:_register_heap_types", kwlist, + &queuetype, &emptyerror, &fullerror)) { return NULL; } if (!PyType_Check(queuetype)) { - PyErr_SetString(PyExc_TypeError, "expected a type for 'queuetype'"); + PyErr_SetString(PyExc_TypeError, + "expected a type for 'queuetype'"); + return NULL; + } + if (!PyExceptionClass_Check(emptyerror)) { + PyErr_SetString(PyExc_TypeError, + "expected an exception type for 'emptyerror'"); + return NULL; + } + if (!PyExceptionClass_Check(fullerror)) { + PyErr_SetString(PyExc_TypeError, + "expected an exception type for 'fullerror'"); return NULL; } - PyTypeObject *cls_queue = (PyTypeObject *)queuetype; - if (set_external_queue_type(self, cls_queue) < 0) { + module_state *state = get_module_state(self); + + if (set_external_queue_type(state, (PyTypeObject *)queuetype) < 0) { + return NULL; + } + if (set_external_exc_types(state, emptyerror, fullerror) < 0) { return NULL; } @@ -1638,23 +1775,23 @@ static PyMethodDef module_functions[] = { METH_VARARGS | METH_KEYWORDS, queuesmod_destroy_doc}, {"list_all", queuesmod_list_all, METH_NOARGS, queuesmod_list_all_doc}, - {"put", _PyCFunction_CAST(queuesmod_put), + {"put", _PyCFunction_CAST(queuesmod_put), METH_VARARGS | METH_KEYWORDS, queuesmod_put_doc}, - {"get", _PyCFunction_CAST(queuesmod_get), + {"get", _PyCFunction_CAST(queuesmod_get), METH_VARARGS | METH_KEYWORDS, queuesmod_get_doc}, - {"bind", _PyCFunction_CAST(queuesmod_bind), + {"bind", _PyCFunction_CAST(queuesmod_bind), METH_VARARGS | METH_KEYWORDS, queuesmod_bind_doc}, {"release", _PyCFunction_CAST(queuesmod_release), METH_VARARGS | METH_KEYWORDS, queuesmod_release_doc}, {"get_maxsize", _PyCFunction_CAST(queuesmod_get_maxsize), METH_VARARGS | METH_KEYWORDS, queuesmod_get_maxsize_doc}, - {"get_default_fmt", _PyCFunction_CAST(queuesmod_get_default_fmt), - METH_VARARGS | METH_KEYWORDS, queuesmod_get_default_fmt_doc}, + {"get_queue_defaults", _PyCFunction_CAST(queuesmod_get_queue_defaults), + METH_VARARGS | METH_KEYWORDS, queuesmod_get_queue_defaults_doc}, {"is_full", _PyCFunction_CAST(queuesmod_is_full), METH_VARARGS | METH_KEYWORDS, queuesmod_is_full_doc}, {"get_count", _PyCFunction_CAST(queuesmod_get_count), METH_VARARGS | METH_KEYWORDS, queuesmod_get_count_doc}, - {"_register_queue_type", _PyCFunction_CAST(queuesmod__register_queue_type), + {"_register_heap_types", _PyCFunction_CAST(queuesmod__register_heap_types), METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL} /* sentinel */