diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index c5788fca2c3..142fa04e388 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1694,8 +1694,6 @@ class TestExamples(unittest.TestCase): def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_accumulate_reducible(self): # check copy, deepcopy, pickle data = [1, 2, 3, 4, 5] diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index cc5cd35df7c..91dbd7489db 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1092,7 +1092,54 @@ mod decl { } #[pyclass(with(IterNext, Iterable, Constructor))] - impl PyItertoolsAccumulate {} + impl PyItertoolsAccumulate { + #[pymethod(magic)] + fn setstate(zelf: PyRef, state: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> { + *zelf.acc_value.write() = Some(state); + Ok(()) + } + + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { + let class = zelf.class().to_owned(); + let binop = zelf.binop.clone(); + let it = zelf.iterable.clone(); + let acc_value = zelf.acc_value.read().clone(); + if let Some(initial) = &zelf.initial { + let chain_args = PyList::from(vec![initial.clone(), it.to_pyobject(vm)]); + let chain = PyItertoolsChain { + source: PyRwLock::new(Some(chain_args.to_pyobject(vm).get_iter(vm).unwrap())), + active: PyRwLock::new(None), + }; + let tup = vm.new_tuple((chain, binop)); + return vm.new_tuple((class, tup, acc_value)); + } + match acc_value { + Some(obj) if obj.is(&vm.ctx.none) => { + let chain_args = PyList::from(vec![]); + let chain = PyItertoolsChain { + source: PyRwLock::new(Some( + chain_args.to_pyobject(vm).get_iter(vm).unwrap(), + )), + active: PyRwLock::new(None), + } + .into_pyobject(vm); + let acc = Self { + iterable: PyIter::new(chain), + binop, + initial: None, + acc_value: PyRwLock::new(None), + }; + let tup = vm.new_tuple((acc, 1, None::)); + let islice_cls = PyItertoolsIslice::class(&vm.ctx).to_owned(); + return vm.new_tuple((islice_cls, tup)); + } + _ => {} + } + let tup = vm.new_tuple((it, binop)); + vm.new_tuple((class, tup, acc_value)) + } + } impl SelfIter for PyItertoolsAccumulate {} impl IterNext for PyItertoolsAccumulate {