Skip to content

Commit 3d40cca

Browse files
committed
rt: add optional per-task user data exposed to hooks
In tokio-rs#7197 and tokio-rs#7306, improved capabilities of task hooks were discussed and an initial implementation provided. However, that involved quite wide-reaching changes, modifying every spawn site and introducing a global map to provide the full inheritance capabilities originally proposed. This is the first part of a more basic version where we only use the existing hooks and provide the capabilities for consumers to be able to implement more complex relationships if needed, just adding an optional user data ref to the task header. The additional data is 2*usize, and is not enough to result in the struct requiring more than one cache line. A user is now able to use their own global map to build inheritance capabilities if needed, and this would be made simpler by also exposing the current task user data to the on_task_spawn hook, which a followup will look to do.
1 parent 8ccf2fb commit 3d40cca

File tree

17 files changed

+386
-32
lines changed

17 files changed

+386
-32
lines changed

tokio/src/runtime/builder.rs

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
use crate::runtime::handle::Handle;
44
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback};
55
#[cfg(tokio_unstable)]
6-
use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta};
6+
use crate::runtime::{
7+
metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta, TaskSpawnCallback,
8+
UserData,
9+
};
710
use crate::util::rand::{RngSeed, RngSeedGenerator};
811

912
use crate::runtime::blocking::BlockingPool;
@@ -89,6 +92,9 @@ pub struct Builder {
8992
pub(super) after_unpark: Option<Callback>,
9093

9194
/// To run before each task is spawned.
95+
#[cfg(tokio_unstable)]
96+
pub(super) before_spawn: Option<TaskSpawnCallback>,
97+
#[cfg(not(tokio_unstable))]
9298
pub(super) before_spawn: Option<TaskCallback>,
9399

94100
/// To run before each poll
@@ -731,8 +737,15 @@ impl Builder {
731737
/// Executes function `f` just before a task is spawned.
732738
///
733739
/// `f` is called within the Tokio context, so functions like
734-
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
735-
/// invoked immediately.
740+
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback
741+
/// being invoked immediately.
742+
///
743+
/// `f` must return an `Option<&'static dyn Any>`. A value returned by this callback
744+
/// is attached to the task and can be retrieved using [`TaskMeta::get_data`] in
745+
/// subsequent calls to other hooks for this task such as
746+
/// [`on_before_task_poll`](crate::runtime::Builder::on_before_task_poll),
747+
/// [`on_after_task_poll`](crate::runtime::Builder::on_after_task_poll), and
748+
/// [`on_task_terminate`](crate::runtime::Builder::on_task_terminate).
736749
///
737750
/// This can be used for bookkeeping or monitoring purposes.
738751
///
@@ -755,6 +768,7 @@ impl Builder {
755768
/// let runtime = runtime::Builder::new_current_thread()
756769
/// .on_task_spawn(|_| {
757770
/// println!("spawning task");
771+
/// None::<()>
758772
/// })
759773
/// .build()
760774
/// .unwrap();
@@ -768,13 +782,70 @@ impl Builder {
768782
/// })
769783
/// # }
770784
/// ```
785+
///
786+
/// ```
787+
/// # #[cfg(not(target_family = "wasm"))]
788+
/// # use tokio::runtime;
789+
/// # use std::sync::atomic::{AtomicUsize, Ordering};
790+
/// # pub fn main() {
791+
/// struct YieldingTaskMetadata {
792+
/// pub yield_count: AtomicUsize,
793+
/// }
794+
/// let runtime = runtime::Builder::new_current_thread()
795+
/// .on_task_spawn(|meta| {
796+
/// println!("spawning task {}", meta.id());
797+
/// Some(YieldingTaskMetadata { yield_count: AtomicUsize::new(0) })
798+
/// })
799+
/// .on_after_task_poll(|meta| {
800+
/// if let Some(data) = meta.get_data::<YieldingTaskMetadata>() {
801+
/// println!("task {} yield count: {}", meta.id(), data.yield_count.fetch_add(1, Ordering::Relaxed));
802+
/// }
803+
/// })
804+
/// .on_task_terminate(|meta| {
805+
/// match meta.get_data::<YieldingTaskMetadata>() {
806+
/// Some(data) => {
807+
/// let yield_count = data.yield_count.load(Ordering::Relaxed);
808+
/// println!("task {} total yield count: {}", meta.id(), yield_count);
809+
/// assert!(yield_count == 64);
810+
/// },
811+
/// None => panic!("task has missing or incorrect user data"),
812+
/// }
813+
/// })
814+
/// .build()
815+
/// .unwrap();
816+
///
817+
/// runtime.block_on(async {
818+
/// let _ = tokio::task::spawn(async {
819+
/// for _ in 0..64 {
820+
/// println!("yielding");
821+
/// tokio::task::yield_now().await;
822+
/// }
823+
/// }).await.unwrap();
824+
/// })
825+
/// # }
826+
/// ```
771827
#[cfg(all(not(loom), tokio_unstable))]
772828
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
773-
pub fn on_task_spawn<F>(&mut self, f: F) -> &mut Self
829+
pub fn on_task_spawn<F, T>(&mut self, f: F) -> &mut Self
774830
where
775-
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
831+
F: Fn(&TaskMeta<'_>) -> Option<T> + Send + Sync + 'static,
832+
T: 'static,
776833
{
777-
self.before_spawn = Some(std::sync::Arc::new(f));
834+
use std::any::Any;
835+
836+
fn wrap<F, T>(f: F) -> impl Fn(&TaskMeta<'_>) -> UserData + Send + Sync + 'static
837+
where
838+
F: Fn(&TaskMeta<'_>) -> Option<T> + Send + Sync + 'static,
839+
T: 'static,
840+
{
841+
move |meta| {
842+
f(meta).map(|value| {
843+
let boxed: Box<dyn Any> = Box::new(value);
844+
Box::leak(boxed) as &'static dyn Any
845+
})
846+
}
847+
}
848+
self.before_spawn = Some(std::sync::Arc::new(wrap(f)));
778849
self
779850
}
780851

tokio/src/runtime/config.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
allow(dead_code)
44
)]
55
use crate::runtime::{Callback, TaskCallback};
6+
#[cfg(tokio_unstable)]
7+
use crate::runtime::TaskSpawnCallback;
68
use crate::util::RngSeedGenerator;
79

810
pub(crate) struct Config {
@@ -19,6 +21,9 @@ pub(crate) struct Config {
1921
pub(crate) after_unpark: Option<Callback>,
2022

2123
/// To run before each task is spawned.
24+
#[cfg(tokio_unstable)]
25+
pub(crate) before_spawn: Option<TaskSpawnCallback>,
26+
#[cfg(not(tokio_unstable))]
2227
pub(crate) before_spawn: Option<TaskCallback>,
2328

2429
/// To run after each task is terminated.

tokio/src/runtime/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ cfg_rt! {
391391

392392
mod task_hooks;
393393
pub(crate) use task_hooks::{TaskHooks, TaskCallback};
394+
#[cfg(tokio_unstable)]
395+
pub(crate) use task_hooks::{TaskSpawnCallback, UserData};
394396
cfg_unstable! {
395397
pub use task_hooks::TaskMeta;
396398
}

tokio/src/runtime/scheduler/current_thread/mod.rs

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use crate::runtime::scheduler::{self, Defer, Inject};
55
use crate::runtime::task::{
66
self, JoinHandle, OwnedTasks, Schedule, SpawnLocation, Task, TaskHarnessScheduleHooks,
77
};
8+
#[cfg(tokio_unstable)]
9+
use crate::runtime::UserData;
810
use crate::runtime::{
911
blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics,
1012
};
@@ -456,13 +458,43 @@ impl Handle {
456458
F: crate::future::Future + Send + 'static,
457459
F::Output: Send + 'static,
458460
{
459-
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, spawned_at);
461+
Self::spawn_with_user_data(me, future, id, spawned_at, #[cfg(tokio_unstable)] None)
462+
}
460463

461-
me.task_hooks.spawn(&TaskMeta {
464+
#[track_caller]
465+
pub(crate) fn spawn_with_user_data<F>(
466+
me: &Arc<Self>,
467+
future: F,
468+
id: crate::runtime::task::Id,
469+
spawned_at: SpawnLocation,
470+
#[cfg(tokio_unstable)]
471+
user_data: UserData,
472+
) -> JoinHandle<F::Output>
473+
where
474+
F: crate::future::Future + Send + 'static,
475+
F::Output: Send + 'static,
476+
{
477+
let task_meta = TaskMeta {
462478
id,
463479
spawned_at,
480+
#[cfg(tokio_unstable)]
481+
user_data,
464482
_phantom: Default::default(),
465-
});
483+
};
484+
485+
#[cfg(not(tokio_unstable))]
486+
{
487+
me.task_hooks.spawn(&task_meta);
488+
}
489+
490+
let (handle, notified) = me.shared.owned.bind(
491+
future,
492+
me.clone(),
493+
id,
494+
spawned_at,
495+
#[cfg(tokio_unstable)]
496+
me.task_hooks.spawn(&task_meta),
497+
);
466498

467499
if let Some(notified) = notified {
468500
me.schedule(notified);
@@ -488,16 +520,27 @@ impl Handle {
488520
F: crate::future::Future + 'static,
489521
F::Output: 'static,
490522
{
491-
let (handle, notified) = me
492-
.shared
493-
.owned
494-
.bind_local(future, me.clone(), id, spawned_at);
495-
496-
me.task_hooks.spawn(&TaskMeta {
523+
let task_meta = TaskMeta {
497524
id,
498525
spawned_at,
526+
#[cfg(tokio_unstable)]
527+
user_data: None,
499528
_phantom: Default::default(),
500-
});
529+
};
530+
531+
#[cfg(not(tokio_unstable))]
532+
{
533+
me.task_hooks.spawn(&task_meta);
534+
}
535+
536+
let (handle, notified) = me.shared.owned.bind_local(
537+
future,
538+
me.clone(),
539+
id,
540+
spawned_at,
541+
#[cfg(tokio_unstable)]
542+
me.task_hooks.spawn(&task_meta),
543+
);
501544

502545
if let Some(notified) = notified {
503546
me.schedule(notified);

tokio/src/runtime/scheduler/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ impl Handle {
6868
cfg_rt! {
6969
use crate::future::Future;
7070
use crate::loom::sync::Arc;
71+
#[cfg(tokio_unstable)]
72+
use crate::runtime::UserData;
7173
use crate::runtime::{blocking, task::{Id, SpawnLocation}};
7274
use crate::runtime::context;
7375
use crate::task::JoinHandle;
@@ -130,6 +132,20 @@ cfg_rt! {
130132
}
131133
}
132134

135+
#[cfg(tokio_unstable)]
136+
pub(crate) fn spawn_with_user_data<F>(&self, future: F, id: Id, spawned_at: SpawnLocation, user_data: UserData) -> JoinHandle<F::Output>
137+
where
138+
F: Future + Send + 'static,
139+
F::Output: Send + 'static,
140+
{
141+
match self {
142+
Handle::CurrentThread(h) => current_thread::Handle::spawn_with_user_data(h, future, id, spawned_at, user_data),
143+
144+
#[cfg(feature = "rt-multi-thread")]
145+
Handle::MultiThread(h) => multi_thread::Handle::spawn_with_user_data(h, future, id, spawned_at, user_data),
146+
}
147+
}
148+
133149
/// Spawn a local task
134150
///
135151
/// # Safety

tokio/src/runtime/scheduler/multi_thread/handle.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::runtime::{
88
TaskHooks, TaskMeta,
99
};
1010
use crate::util::RngSeedGenerator;
11+
#[cfg(tokio_unstable)]
12+
use crate::runtime::task_hooks::UserData;
1113

1214
use std::fmt;
1315

@@ -47,7 +49,23 @@ impl Handle {
4749
F: crate::future::Future + Send + 'static,
4850
F::Output: Send + 'static,
4951
{
50-
Self::bind_new_task(me, future, id, spawned_at)
52+
Self::bind_new_task(me, future, id, spawned_at, #[cfg(tokio_unstable)] None)
53+
}
54+
55+
/// Spawns a future with user data onto the thread pool
56+
#[cfg(tokio_unstable)]
57+
pub(crate) fn spawn_with_user_data<F>(
58+
me: &Arc<Self>,
59+
future: F,
60+
id: task::Id,
61+
spawned_at: SpawnLocation,
62+
user_data: UserData,
63+
) -> JoinHandle<F::Output>
64+
where
65+
F: crate::future::Future + Send + 'static,
66+
F::Output: Send + 'static,
67+
{
68+
Self::bind_new_task(me, future, id, spawned_at, user_data)
5169
}
5270

5371
pub(crate) fn shutdown(&self) {
@@ -60,18 +78,34 @@ impl Handle {
6078
future: T,
6179
id: task::Id,
6280
spawned_at: SpawnLocation,
81+
#[cfg(tokio_unstable)]
82+
user_data: UserData,
6383
) -> JoinHandle<T::Output>
6484
where
6585
T: Future + Send + 'static,
6686
T::Output: Send + 'static,
6787
{
68-
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, spawned_at);
69-
70-
me.task_hooks.spawn(&TaskMeta {
88+
let task_meta = TaskMeta {
7189
id,
7290
spawned_at,
91+
#[cfg(tokio_unstable)]
92+
user_data,
7393
_phantom: Default::default(),
74-
});
94+
};
95+
96+
#[cfg(not(tokio_unstable))]
97+
{
98+
me.task_hooks.spawn(&task_meta);
99+
}
100+
101+
let (handle, notified) = me.shared.owned.bind(
102+
future,
103+
me.clone(),
104+
id,
105+
spawned_at,
106+
#[cfg(tokio_unstable)]
107+
me.task_hooks.spawn(&task_meta),
108+
);
75109

76110
me.schedule_option_task_without_yield(notified);
77111

0 commit comments

Comments
 (0)