From 3032fcd90ecb745b737cbc93f694f9a802062a3a Mon Sep 17 00:00:00 2001 From: Nice Zombies Date: Mon, 4 Nov 2024 15:00:19 +0100 Subject: [PATCH] gh-119793: Add optional length-checking to `map()` (GH-120471) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com> Co-authored-by: Pieter Eendebak Co-authored-by: Erlend E. Aasland Co-authored-by: Raymond Hettinger --- Doc/library/functions.rst | 11 +- Doc/whatsnew/3.14.rst | 4 + Lib/test/test_builtin.py | 105 ++++++++++++++++++ Lib/test/test_itertools.py | 4 +- ...-06-13-19-12-49.gh-issue-119793.FDVCDk.rst | 3 + Python/bltinmodule.c | 100 +++++++++++++++-- 6 files changed, 210 insertions(+), 17 deletions(-) create mode 100644 Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst diff --git a/Doc/library/functions.rst b/Doc/library/functions.rst index 03fc41fa793..a7549b9bce7 100644 --- a/Doc/library/functions.rst +++ b/Doc/library/functions.rst @@ -1205,14 +1205,19 @@ are always available. They are listed here in alphabetical order. unchanged from previous versions. -.. function:: map(function, iterable, *iterables) +.. function:: map(function, iterable, /, *iterables, strict=False) Return an iterator that applies *function* to every item of *iterable*, yielding the results. If additional *iterables* arguments are passed, *function* must take that many arguments and is applied to the items from all iterables in parallel. With multiple iterables, the iterator stops when the - shortest iterable is exhausted. For cases where the function inputs are - already arranged into argument tuples, see :func:`itertools.starmap`\. + shortest iterable is exhausted. If *strict* is ``True`` and one of the + iterables is exhausted before the others, a :exc:`ValueError` is raised. For + cases where the function inputs are already arranged into argument tuples, + see :func:`itertools.starmap`. + + .. 
versionchanged:: 3.14 + Added the *strict* parameter. .. function:: max(iterable, *, key=None) diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst index deee683d7b8..80c1a93b95a 100644 --- a/Doc/whatsnew/3.14.rst +++ b/Doc/whatsnew/3.14.rst @@ -175,6 +175,10 @@ Improved error messages Other language changes ====================== +* The :func:`map` built-in now has an optional keyword-only *strict* flag + like :func:`zip` to check that all the iterables are of equal length. + (Contributed by Wannes Boeykens in :gh:`119793`.) + * Incorrect usage of :keyword:`await` and asynchronous comprehensions is now detected even if the code is optimized away by the :option:`-O` command-line option. For example, ``python -O -c 'assert await 1'`` diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index eb5906f8944..f8e6f05cd60 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -148,6 +148,9 @@ def filter_char(arg): def map_char(arg): return chr(ord(arg)+1) +def pack(*args): + return args + class BuiltinTest(unittest.TestCase): # Helper to check picklability def check_iter_pickle(self, it, seq, proto): @@ -1269,6 +1272,108 @@ class BuiltinTest(unittest.TestCase): m2 = map(map_char, "Is this the real life?") self.check_iter_pickle(m1, list(m2), proto) + # strict map tests based on strict zip tests + + def test_map_pickle_strict(self): + a = (1, 2, 3) + b = (4, 5, 6) + t = [(1, 4), (2, 5), (3, 6)] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + m1 = map(pack, a, b, strict=True) + self.check_iter_pickle(m1, t, proto) + + def test_map_pickle_strict_fail(self): + a = (1, 2, 3) + b = (4, 5, 6, 7) + t = [(1, 4), (2, 5), (3, 6)] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + m1 = map(pack, a, b, strict=True) + m2 = pickle.loads(pickle.dumps(m1, proto)) + self.assertEqual(self.iter_error(m1, ValueError), t) + self.assertEqual(self.iter_error(m2, ValueError), t) + + def test_map_strict(self): + self.assertEqual(tuple(map(pack, (1, 2, 3), 
'abc', strict=True)), + ((1, 'a'), (2, 'b'), (3, 'c'))) + self.assertRaises(ValueError, tuple, + map(pack, (1, 2, 3, 4), 'abc', strict=True)) + self.assertRaises(ValueError, tuple, + map(pack, (1, 2), 'abc', strict=True)) + self.assertRaises(ValueError, tuple, + map(pack, (1, 2), (1, 2), 'abc', strict=True)) + + def test_map_strict_iterators(self): + x = iter(range(5)) + y = [0] + z = iter(range(5)) + self.assertRaises(ValueError, list, + (map(pack, x, y, z, strict=True))) + self.assertEqual(next(x), 2) + self.assertEqual(next(z), 1) + + def test_map_strict_error_handling(self): + + class Error(Exception): + pass + + class Iter: + def __init__(self, size): + self.size = size + def __iter__(self): + return self + def __next__(self): + self.size -= 1 + if self.size < 0: + raise Error + return self.size + + l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error) + self.assertEqual(l1, [("A", 0)]) + l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError) + self.assertEqual(l2, [("A", 1, "A")]) + l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error) + self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")]) + l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError) + self.assertEqual(l4, [("A", 2), ("B", 1)]) + l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error) + self.assertEqual(l5, [(0, "A")]) + l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError) + self.assertEqual(l6, [(1, "A")]) + l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error) + self.assertEqual(l7, [(1, "A"), (0, "B")]) + l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError) + self.assertEqual(l8, [(2, "A"), (1, "B")]) + + def test_map_strict_error_handling_stopiteration(self): + + class Iter: + def __init__(self, size): + self.size = size + def __iter__(self): + return self + def __next__(self): + self.size -= 1 + if self.size < 0: + raise StopIteration + return self.size + + l1 
= self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError) + self.assertEqual(l1, [("A", 0)]) + l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError) + self.assertEqual(l2, [("A", 1, "A")]) + l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError) + self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")]) + l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError) + self.assertEqual(l4, [("A", 2), ("B", 1)]) + l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError) + self.assertEqual(l5, [(0, "A")]) + l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError) + self.assertEqual(l6, [(1, "A")]) + l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError) + self.assertEqual(l7, [(1, "A"), (0, "B")]) + l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError) + self.assertEqual(l8, [(2, "A"), (1, "B")]) + def test_max(self): self.assertEqual(max('123123'), '3') self.assertEqual(max(1, 2, 3), 3) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 8469de998ba..a52e1d3fa14 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -2433,10 +2433,10 @@ class SubclassWithKwargsTest(unittest.TestCase): subclass(*args, newarg=3) for cls, args, result in testcases: - # Constructors of repeat, zip, compress accept keyword arguments. + # Constructors of repeat, zip, map, compress accept keyword arguments. # Their subclasses need overriding __new__ to support new # keyword arguments. 
-            if cls in [repeat, zip, compress]:
+            if cls in [repeat, zip, map, compress]:
                 continue
             with self.subTest(cls):
                 class subclass_with_init(cls):
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
new file mode 100644
index 00000000000..976d6712e4b
--- /dev/null
+++ b/Misc/NEWS.d/next/Core_and_Builtins/2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
@@ -0,0 +1,3 @@
+The :func:`map` built-in now has an optional keyword-only *strict* flag
+like :func:`zip` to check that all the iterables are of equal length.
+Patch by Wannes Boeykens.
diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c
index 12f065d4b4f..85a28de2bb9 100644
--- a/Python/bltinmodule.c
+++ b/Python/bltinmodule.c
@@ -1311,6 +1311,7 @@ typedef struct {
     PyObject_HEAD
     PyObject *iters;
     PyObject *func;
+    int strict;
 } mapobject;
 
 static PyObject *
@@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     PyObject *it, *iters, *func;
     mapobject *lz;
     Py_ssize_t numargs, i;
+    int strict = 0;
 
-    if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
-        !_PyArg_NoKeywords("map", kwds))
-        return NULL;
+    if (kwds) {
+        PyObject *empty = PyTuple_New(0);
+        if (empty == NULL) {
+            return NULL;
+        }
+        static char *kwlist[] = {"strict", NULL};
+        int parsed = PyArg_ParseTupleAndKeywords(
+                empty, kwds, "|$p:map", kwlist, &strict);
+        Py_DECREF(empty);
+        if (!parsed) {
+            return NULL;
+        }
+    }
 
     numargs = PyTuple_Size(args);
     if (numargs < 2) {
@@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     lz->iters = iters;
     func = PyTuple_GET_ITEM(args, 0);
     lz->func = Py_NewRef(func);
+    lz->strict = strict;
 
     return (PyObject *)lz;
 }
@@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
                 size_t nargsf, PyObject *kwnames)
 {
     PyTypeObject *tp = _PyType_CAST(type);
-    if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
-        return NULL;
-    }
 
     Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
+    if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
+        // Fallback to map_new()
+        PyThreadState *tstate = _PyThreadState_GET();
+        return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
+    }
+
     if (nargs < 2) {
         PyErr_SetString(PyExc_TypeError,
            "map() must have at least two arguments.");
@@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
     }
     lz->iters = iters;
     lz->func = Py_NewRef(args[0]);
+    lz->strict = 0;
 
     return (PyObject *)lz;
 }
@@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
 static PyObject *
 map_next(mapobject *lz)
 {
+    Py_ssize_t i;
     PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
     PyObject **stack;
     PyObject *result = NULL;
@@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
     }
 
     Py_ssize_t nargs = 0;
-    for (Py_ssize_t i=0; i < niters; i++) {
+    for (i=0; i < niters; i++) {
         PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
         PyObject *val = Py_TYPE(it)->tp_iternext(it);
         if (val == NULL) {
+            if (lz->strict) {
+                goto check;
+            }
             goto exit;
         }
         stack[i] = val;
@@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
     result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
 
 exit:
-    for (Py_ssize_t i=0; i < nargs; i++) {
+    for (i=0; i < nargs; i++) {
         Py_DECREF(stack[i]);
     }
     if (stack != small_stack) {
         PyMem_Free(stack);
     }
     return result;
+check:  // Strict-mode length check; result is still NULL on every path here.
+    if (PyErr_Occurred()) {
+        if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+            // next() on argument i raised an exception (not StopIteration)
+            goto exit;  // never return directly: exit releases stack[]
+        }
+        PyErr_Clear();
+    }
+    if (i) {
+        // ValueError: map() argument 2 is shorter than argument 1
+        // ValueError: map() argument 3 is shorter than arguments 1-2
+        const char* plural = i == 1 ? " " : "s 1-";
+        PyErr_Format(PyExc_ValueError, "map() argument %zd is shorter "
+                     "than argument%s%zd", i + 1, plural, i);
+        goto exit;  // drops the i values already fetched into stack[]
+    }
+    for (i = 1; i < niters; i++) {
+        PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
+        PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
+        if (val) {
+            Py_DECREF(val);
+            const char* plural = i == 1 ? " " : "s 1-";
+            PyErr_Format(PyExc_ValueError, "map() argument %zd is longer "
+                         "than argument%s%zd", i + 1, plural, i);
+            goto exit;  // nargs is 0 here, but stack may be heap-allocated
+        }
+        if (PyErr_Occurred()) {
+            if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+                // next() on argument i raised an exception (not StopIteration)
+                goto exit;
+            }
+            PyErr_Clear();
+        }
+        // Argument i is exhausted.  So far so good...
+    }
+    // All arguments are exhausted.  Success!
+    goto exit;
 }
 
 static PyObject *
@@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
         PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
     }
 
+    if (lz->strict) {
+        return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
+    }
     return Py_BuildValue("ON", Py_TYPE(lz), args);
 }
 
+PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
+
+static PyObject *
+map_setstate(mapobject *lz, PyObject *state)
+{
+    int strict = PyObject_IsTrue(state);
+    if (strict < 0) {
+        return NULL;
+    }
+    lz->strict = strict;
+    Py_RETURN_NONE;
+}
+
 static PyMethodDef map_methods[] = {
     {"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
+    {"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
     {NULL,           NULL}           /* sentinel */
 };
 
 
 PyDoc_STRVAR(map_doc,
-"map(function, iterable, /, *iterables)\n\
+"map(function, iterable, /, *iterables, strict=False)\n\
 --\n\
 \n\
 Make an iterator that computes the function using arguments from\n\
-each of the iterables.  Stops when the shortest iterable is exhausted.");
+each of the iterables.  Stops when the shortest iterable is exhausted.\n\
+\n\
+If strict is true and one of the arguments is exhausted before the others,\n\
+raise a ValueError.");
 
 PyTypeObject PyMap_Type = {
     PyVarObject_HEAD_INIT(&PyType_Type, 0)
@@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
     return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
 }
 
-PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
-
 static PyObject *
 zip_setstate(zipobject *lz, PyObject *state)
 {