diff --git a/extra_tests/snippets/stdlib_itertools.py b/extra_tests/snippets/stdlib_itertools.py index 05376d4b05a..a5b91d0cde9 100644 --- a/extra_tests/snippets/stdlib_itertools.py +++ b/extra_tests/snippets/stdlib_itertools.py @@ -50,6 +50,18 @@ with assert_raises(TypeError): next(x) +# iterables are lazily evaluted +x = chain.from_iterable(itertools.repeat(range(2))) +assert next(x) == 0 +assert next(x) == 1 +assert next(x) == 0 +assert next(x) == 1 + +x = chain(1, [2]) +with assert_raises(TypeError): + next(x) +with assert_raises(StopIteration): + next(x) # itertools.count tests diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index cc4eeb2d00c..26f2710920f 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -7,7 +7,7 @@ mod decl { rc::PyRc, }; use crate::{ - builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyTuple, PyTupleRef, PyTypeRef}, + builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef}, convert::ToPyObject, function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs}, identifier, @@ -25,19 +25,18 @@ mod decl { #[pyclass(name = "chain")] #[derive(Debug, PyPayload)] struct PyItertoolsChain { - iterables: Vec, - cur_idx: AtomicCell, - cached_iter: PyRwLock>, + source: PyRwLock>, + active: PyRwLock>, } #[pyimpl(with(IterNext))] impl PyItertoolsChain { #[pyslot] fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let args_list = PyList::from(args.args); PyItertoolsChain { - iterables: args.args, - cur_idx: AtomicCell::new(0), - cached_iter: PyRwLock::new(None), + source: PyRwLock::new(Some(args_list.to_pyobject(vm).get_iter(vm)?)), + active: PyRwLock::new(None), } .into_ref_with_type(vm, cls) .map(Into::into) @@ -46,13 +45,12 @@ mod decl { #[pyclassmethod] fn from_iterable( cls: PyTypeRef, - iterable: PyObjectRef, + source: PyObjectRef, vm: &VirtualMachine, ) -> PyResult> { PyItertoolsChain { - iterables: iterable.try_to_value(vm)?, - cur_idx: AtomicCell::new(0), - cached_iter: PyRwLock::new(None), + source: PyRwLock::new(Some(source.get_iter(vm)?)), + active: PyRwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -65,37 +63,51 @@ mod decl { impl IterNextIterable for PyItertoolsChain {} impl IterNext for PyItertoolsChain { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { - loop { - let pos = zelf.cur_idx.load(); - if pos >= zelf.iterables.len() { - break; - } - let cur_iter = if zelf.cached_iter.read().is_none() { - // We need to call "get_iter" outside of the lock. - let iter = zelf.iterables[pos].clone().get_iter(vm)?; - *zelf.cached_iter.write() = Some(iter.clone()); - iter - } else if let Some(cached_iter) = (*zelf.cached_iter.read()).clone() { - cached_iter - } else { - // Someone changed cached iter to None since we checked. - continue; - }; - - // We need to call "next" outside of the lock. - match cur_iter.next(vm) { - Ok(PyIterReturn::Return(ok)) => return Ok(PyIterReturn::Return(ok)), - Ok(PyIterReturn::StopIteration(_)) => { - zelf.cur_idx.fetch_add(1); - *zelf.cached_iter.write() = None; + let source = if let Some(source) = zelf.source.read().clone() { + source + } else { + return Ok(PyIterReturn::StopIteration(None)); + }; + let next = loop { + let maybe_active = zelf.active.read().clone(); + if let Some(active) = maybe_active { + match active.next(vm) { + Ok(PyIterReturn::Return(ok)) => { + break Ok(PyIterReturn::Return(ok)); + } + Ok(PyIterReturn::StopIteration(_)) => { + *zelf.active.write() = None; + } + Err(err) => { + break Err(err); + } } - Err(err) => { - return Err(err); + } else { + match source.next(vm) { + Ok(PyIterReturn::Return(ok)) => match ok.get_iter(vm) { + Ok(iter) => { + *zelf.active.write() = Some(iter); + } + Err(err) => { + break Err(err); + } + }, + Ok(PyIterReturn::StopIteration(_)) => { + break Ok(PyIterReturn::StopIteration(None)); + } + Err(err) => { + break Err(err); + } } } - } - - Ok(PyIterReturn::StopIteration(None)) + }; + match next { + Err(_) | Ok(PyIterReturn::StopIteration(_)) => { + *zelf.source.write() = None; + } + _ => {} + }; + next } }