diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 23ede28c8eb..cd5b49d16d3 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -500,8 +500,6 @@ def test_combinatorics(self): self.assertEqual(comb, list(filter(set(perm).__contains__, cwr))) # comb: cwr that is a perm self.assertEqual(comb, sorted(set(cwr) & set(perm))) # comb: both a cwr and a perm - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_compress(self): self.assertEqual(list(compress(data='ABCDEF', selectors=[1,0,1,0,1,1])), list('ACEF')) self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF')) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 26f2710920f..fc7e7434e49 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -116,15 +116,15 @@ mod decl { #[derive(Debug, PyPayload)] struct PyItertoolsCompress { data: PyIter, - selector: PyIter, + selectors: PyIter, } #[derive(FromArgs)] struct CompressNewArgs { - #[pyarg(positional)] + #[pyarg(any)] data: PyIter, - #[pyarg(positional)] - selector: PyIter, + #[pyarg(any)] + selectors: PyIter, } impl Constructor for PyItertoolsCompress { @@ -132,23 +132,31 @@ mod decl { fn py_new( cls: PyTypeRef, - Self::Args { data, selector }: Self::Args, + Self::Args { data, selectors }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsCompress { data, selector } + PyItertoolsCompress { data, selectors } .into_ref_with_type(vm, cls) .map(Into::into) } } #[pyimpl(with(IterNext, Constructor))] - impl PyItertoolsCompress {} + impl PyItertoolsCompress { + #[pymethod(magic)] + fn reduce(zelf: PyRef) -> (PyTypeRef, (PyIter, PyIter)) { + ( + zelf.class().clone(), + (zelf.data.clone(), zelf.selectors.clone()), + ) + } + } impl IterNextIterable for PyItertoolsCompress {} impl IterNext for PyItertoolsCompress { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { loop { - let sel_obj = match zelf.selector.next(vm)? { + let sel_obj = match zelf.selectors.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), };