0
0
mirror of https://github.com/python/cpython.git synced 2024-11-21 12:59:38 +01:00

gh-119793: Add optional length-checking to map() (GH-120471)

Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
This commit is contained in:
Nice Zombies 2024-11-04 15:00:19 +01:00 committed by GitHub
parent bfc1d2504c
commit 3032fcd90e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 210 additions and 17 deletions

View File

@ -1205,14 +1205,19 @@ are always available. They are listed here in alphabetical order.
unchanged from previous versions.
.. function:: map(function, iterable, *iterables)
.. function:: map(function, iterable, /, *iterables, strict=False)
Return an iterator that applies *function* to every item of *iterable*,
yielding the results. If additional *iterables* arguments are passed,
*function* must take that many arguments and is applied to the items from all
iterables in parallel. With multiple iterables, the iterator stops when the
shortest iterable is exhausted. For cases where the function inputs are
already arranged into argument tuples, see :func:`itertools.starmap`\.
shortest iterable is exhausted. If *strict* is ``True`` and one of the
iterables is exhausted before the others, a :exc:`ValueError` is raised. For
cases where the function inputs are already arranged into argument tuples,
see :func:`itertools.starmap`.
.. versionchanged:: 3.14
Added the *strict* parameter.
.. function:: max(iterable, *, key=None)

View File

@ -175,6 +175,10 @@ Improved error messages
Other language changes
======================
* The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
(Contributed by Wannes Boeykens in :gh:`119793`.)
* Incorrect usage of :keyword:`await` and asynchronous comprehensions
is now detected even if the code is optimized away by the :option:`-O`
command-line option. For example, ``python -O -c 'assert await 1'``

View File

@ -148,6 +148,9 @@ def filter_char(arg):
def map_char(arg):
return chr(ord(arg)+1)
def pack(*args):
return args
class BuiltinTest(unittest.TestCase):
# Helper to check picklability
def check_iter_pickle(self, it, seq, proto):
@ -1269,6 +1272,108 @@ class BuiltinTest(unittest.TestCase):
m2 = map(map_char, "Is this the real life?")
self.check_iter_pickle(m1, list(m2), proto)
# strict map tests based on strict zip tests
def test_map_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
self.check_iter_pickle(m1, t, proto)
def test_map_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
m2 = pickle.loads(pickle.dumps(m1, proto))
self.assertEqual(self.iter_error(m1, ValueError), t)
self.assertEqual(self.iter_error(m2, ValueError), t)
def test_map_strict(self):
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), (1, 2), 'abc', strict=True))
def test_map_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(map(pack, x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)
def test_map_strict_error_handling(self):
class Error(Exception):
pass
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_map_strict_error_handling_stopiteration(self):
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_max(self):
self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3)

View File

@ -2433,10 +2433,10 @@ class SubclassWithKwargsTest(unittest.TestCase):
subclass(*args, newarg=3)
for cls, args, result in testcases:
# Constructors of repeat, zip, compress accept keyword arguments.
# Constructors of repeat, zip, map, compress accept keyword arguments.
# Their subclasses need overriding __new__ to support new
# keyword arguments.
if cls in [repeat, zip, compress]:
if cls in [repeat, zip, map, compress]:
continue
with self.subTest(cls):
class subclass_with_init(cls):

View File

@ -0,0 +1,3 @@
The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
Patch by Wannes Boeykens.

View File

@ -1311,6 +1311,7 @@ typedef struct {
PyObject_HEAD
PyObject *iters;
PyObject *func;
int strict;
} mapobject;
static PyObject *
@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *it, *iters, *func;
mapobject *lz;
Py_ssize_t numargs, i;
int strict = 0;
if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
!_PyArg_NoKeywords("map", kwds))
return NULL;
if (kwds) {
PyObject *empty = PyTuple_New(0);
if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:map", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}
numargs = PyTuple_Size(args);
if (numargs < 2) {
@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->iters = iters;
func = PyTuple_GET_ITEM(args, 0);
lz->func = Py_NewRef(func);
lz->strict = strict;
return (PyObject *)lz;
}
@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
size_t nargsf, PyObject *kwnames)
{
PyTypeObject *tp = _PyType_CAST(type);
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
return NULL;
}
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
// Fallback to map_new()
PyThreadState *tstate = _PyThreadState_GET();
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
}
if (nargs < 2) {
PyErr_SetString(PyExc_TypeError,
"map() must have at least two arguments.");
@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
}
lz->iters = iters;
lz->func = Py_NewRef(args[0]);
lz->strict = 0;
return (PyObject *)lz;
}
@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
static PyObject *
map_next(mapobject *lz)
{
Py_ssize_t i;
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
PyObject **stack;
PyObject *result = NULL;
@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
}
Py_ssize_t nargs = 0;
for (Py_ssize_t i=0; i < niters; i++) {
for (i=0; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = Py_TYPE(it)->tp_iternext(it);
if (val == NULL) {
if (lz->strict) {
goto check;
}
goto exit;
}
stack[i] = val;
@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
exit:
for (Py_ssize_t i=0; i < nargs; i++) {
for (i=0; i < nargs; i++) {
Py_DECREF(stack[i]);
}
if (stack != small_stack) {
PyMem_Free(stack);
}
return result;
check:
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
if (i) {
// ValueError: map() argument 2 is shorter than argument 1
// ValueError: map() argument 3 is shorter than arguments 1-2
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is shorter than argument%s%d",
i + 1, plural, i);
}
for (i = 1; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
if (val) {
Py_DECREF(val);
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is longer than argument%s%d",
i + 1, plural, i);
}
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
// Argument i is exhausted. So far so good...
}
// All arguments are exhausted. Success!
goto exit;
}
static PyObject *
@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
}
if (lz->strict) {
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
}
return Py_BuildValue("ON", Py_TYPE(lz), args);
}
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
map_setstate(mapobject *lz, PyObject *state)
{
int strict = PyObject_IsTrue(state);
if (strict < 0) {
return NULL;
}
lz->strict = strict;
Py_RETURN_NONE;
}
static PyMethodDef map_methods[] = {
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
{NULL, NULL} /* sentinel */
};
PyDoc_STRVAR(map_doc,
"map(function, iterable, /, *iterables)\n\
"map(function, iterable, /, *iterables, strict=False)\n\
--\n\
\n\
Make an iterator that computes the function using arguments from\n\
each of the iterables. Stops when the shortest iterable is exhausted.");
each of the iterables. Stops when the shortest iterable is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");
PyTypeObject PyMap_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
}
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
zip_setstate(zipobject *lz, PyObject *state)
{