1//! Futures task based helpers to easily test futures and manually written futures.
2//!
//! The [`Spawn`] type is used as a mock task harness that allows you to poll futures
//! without needing to set up pinning or context. Any future can be polled, but if the
//! future requires the tokio async context, you will need to ensure that you poll the
//! [`Spawn`] within a tokio context. As long as you are inside the runtime, polling it
//! via [`Spawn`] will work.
8//!
9//! [`Spawn`] also supports [`Stream`] to call `poll_next` without pinning
10//! or context.
11//!
//! In addition to circumventing the need for pinning and context, [`Spawn`] also tracks
//! the number of times the future/task was woken. This can be useful to check whether
//! some leaf future notified the root task correctly.
15//!
16//! # Example
17//!
18//! ```
19//! use tokio_test::task;
20//!
21//! let fut = async {};
22//!
23//! let mut task = task::spawn(fut);
24//!
25//! assert!(task.poll().is_ready(), "Task was not ready!");
26//! ```
2728use std::future::Future;
29use std::mem;
30use std::ops;
31use std::pin::Pin;
32use std::sync::{Arc, Condvar, Mutex};
33use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
3435use tokio_stream::Stream;
3637/// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor.
38///
39/// This can be used to spawn a [`Future`] or a [`Stream`].
40///
41/// For more information, check the module docs.
42pub fn spawn<T>(task: T) -> Spawn<T> {
43 Spawn {
44 task: MockTask::new(),
45 future: Box::pin(task),
46 }
47}
4849/// Future spawned on a mock task that can be used to poll the future or stream
50/// without needing pinning or context types.
51#[derive(Debug)]
52#[must_use = "futures do nothing unless you `.await` or poll them"]
53pub struct Spawn<T> {
54 task: MockTask,
55 future: Pin<Box<T>>,
56}
5758#[derive(Debug, Clone)]
59struct MockTask {
60 waker: Arc<ThreadWaker>,
61}
/// Wake-tracking state: a small state machine guarded by a mutex, plus a
/// condition variable.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state; one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    condvar: Condvar,
}
/// Waker state: no wake notification recorded since the last `clear`.
const IDLE: usize = 0;
/// Waker state: a wake notification has been recorded.
const WAKE: usize = IDLE + 1;
/// Waker state: the other half is sleeping and must be notified through the
/// condvar. NOTE(review): nothing in this file appears to store this value.
const SLEEP: usize = WAKE + 1;
7273impl<T> Spawn<T> {
74/// Consumes `self` returning the inner value
75pub fn into_inner(self) -> T
76where
77T: Unpin,
78 {
79*Pin::into_inner(self.future)
80 }
8182/// Returns `true` if the inner future has received a wake notification
83 /// since the last call to `enter`.
84pub fn is_woken(&self) -> bool {
85self.task.is_woken()
86 }
8788/// Returns the number of references to the task waker
89 ///
90 /// The task itself holds a reference. The return value will never be zero.
91pub fn waker_ref_count(&self) -> usize {
92self.task.waker_ref_count()
93 }
9495/// Enter the task context
96pub fn enter<F, R>(&mut self, f: F) -> R
97where
98F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
99 {
100let fut = self.future.as_mut();
101self.task.enter(|cx| f(cx, fut))
102 }
103}
104105impl<T: Unpin> ops::Deref for Spawn<T> {
106type Target = T;
107108fn deref(&self) -> &T {
109&self.future
110 }
111}
112113impl<T: Unpin> ops::DerefMut for Spawn<T> {
114fn deref_mut(&mut self) -> &mut T {
115&mut self.future
116 }
117}
118119impl<T: Future> Spawn<T> {
120/// If `T` is a [`Future`] then poll it. This will handle pinning and the context
121 /// type for the future.
122pub fn poll(&mut self) -> Poll<T::Output> {
123let fut = self.future.as_mut();
124self.task.enter(|cx| fut.poll(cx))
125 }
126}
127128impl<T: Stream> Spawn<T> {
129/// If `T` is a [`Stream`] then `poll_next` it. This will handle pinning and the context
130 /// type for the stream.
131pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
132let stream = self.future.as_mut();
133self.task.enter(|cx| stream.poll_next(cx))
134 }
135}
136137impl<T: Future> Future for Spawn<T> {
138type Output = T::Output;
139140fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141self.future.as_mut().poll(cx)
142 }
143}
144145impl<T: Stream> Stream for Spawn<T> {
146type Item = T::Item;
147148fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149self.future.as_mut().poll_next(cx)
150 }
151152fn size_hint(&self) -> (usize, Option<usize>) {
153self.future.size_hint()
154 }
155}
156157impl MockTask {
158/// Creates new mock task
159fn new() -> Self {
160 MockTask {
161 waker: Arc::new(ThreadWaker::new()),
162 }
163 }
164165/// Runs a closure from the context of the task.
166 ///
167 /// Any wake notifications resulting from the execution of the closure are
168 /// tracked.
169fn enter<F, R>(&mut self, f: F) -> R
170where
171F: FnOnce(&mut Context<'_>) -> R,
172 {
173self.waker.clear();
174let waker = self.waker();
175let mut cx = Context::from_waker(&waker);
176177 f(&mut cx)
178 }
179180/// Returns `true` if the inner future has received a wake notification
181 /// since the last call to `enter`.
182fn is_woken(&self) -> bool {
183self.waker.is_woken()
184 }
185186/// Returns the number of references to the task waker
187 ///
188 /// The task itself holds a reference. The return value will never be zero.
189fn waker_ref_count(&self) -> usize {
190 Arc::strong_count(&self.waker)
191 }
192193fn waker(&self) -> Waker {
194unsafe {
195let raw = to_raw(self.waker.clone());
196 Waker::from_raw(raw)
197 }
198 }
199}
200201impl Default for MockTask {
202fn default() -> Self {
203Self::new()
204 }
205}
206207impl ThreadWaker {
208fn new() -> Self {
209 ThreadWaker {
210 state: Mutex::new(IDLE),
211 condvar: Condvar::new(),
212 }
213 }
214215/// Clears any previously received wakes, avoiding potential spurious
216 /// wake notifications. This should only be called immediately before running the
217 /// task.
218fn clear(&self) {
219*self.state.lock().unwrap() = IDLE;
220 }
221222fn is_woken(&self) -> bool {
223match *self.state.lock().unwrap() {
224 IDLE => false,
225 WAKE => true,
226_ => unreachable!(),
227 }
228 }
229230fn wake(&self) {
231// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
232let mut state = self.state.lock().unwrap();
233let prev = *state;
234235if prev == WAKE {
236return;
237 }
238239*state = WAKE;
240241if prev == IDLE {
242return;
243 }
244245// The other half is sleeping, so we wake it up.
246assert_eq!(prev, SLEEP);
247self.condvar.notify_one();
248 }
249}
250251static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
252253unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
254 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
255}
256257unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
258 Arc::from_raw(raw as *const ThreadWaker)
259}
260261unsafe fn clone(raw: *const ()) -> RawWaker {
262let waker = from_raw(raw);
263264// Increment the ref count
265mem::forget(waker.clone());
266267 to_raw(waker)
268}
269270unsafe fn wake(raw: *const ()) {
271let waker = from_raw(raw);
272 waker.wake();
273}
274275unsafe fn wake_by_ref(raw: *const ()) {
276let waker = from_raw(raw);
277 waker.wake();
278279// We don't actually own a reference to the unparker
280mem::forget(waker);
281}
282283unsafe fn drop_waker(raw: *const ()) {
284let _ = from_raw(raw);
285}