diff --git a/embassy-executor/src/raw/mod.rs b/embassy-executor/src/raw/mod.rs index e439415595..c6c0f9dc4a 100644 --- a/embassy-executor/src/raw/mod.rs +++ b/embassy-executor/src/raw/mod.rs @@ -22,6 +22,8 @@ pub(crate) mod util; #[cfg_attr(feature = "turbowakers", path = "waker_turbo.rs")] mod waker; +#[cfg(feature = "integrated-timers")] +use core::cell::Cell; use core::future::Future; use core::marker::PhantomData; use core::mem; @@ -29,8 +31,6 @@ use core::pin::Pin; use core::ptr::NonNull; use core::task::{Context, Poll}; -#[cfg(feature = "integrated-timers")] -use core::cell::Cell; #[cfg(feature = "integrated-timers")] use critical_section::Mutex; #[cfg(feature = "integrated-timers")] @@ -54,6 +54,8 @@ pub(crate) struct TaskHeader { #[cfg(feature = "integrated-timers")] pub(crate) expires_at: Mutex>, #[cfg(feature = "integrated-timers")] + pub(crate) next_expiration: SyncUnsafeCell, + #[cfg(feature = "integrated-timers")] pub(crate) timer_queue_item: timer_queue::TimerQueueItem, } @@ -127,6 +129,8 @@ impl TaskStorage { #[cfg(feature = "integrated-timers")] expires_at: Mutex::new(Cell::new(0)), #[cfg(feature = "integrated-timers")] + next_expiration: SyncUnsafeCell::new(0), + #[cfg(feature = "integrated-timers")] timer_queue_item: timer_queue::TimerQueueItem::new(), }, future: UninitCell::uninit(), @@ -166,9 +170,7 @@ impl TaskStorage { this.raw.state.despawn(); #[cfg(feature = "integrated-timers")] - critical_section::with(|cs| { - this.raw.expires_at.borrow(cs).set(u64::MAX); - }); + this.raw.next_expiration.set(u64::MAX); } Poll::Pending => {} } @@ -391,16 +393,13 @@ impl SyncExecutor { /// /// Same as [`Executor::poll`], plus you must only call this on the thread this executor was created. pub(crate) unsafe fn poll(&'static self) { - //trace!("poll"); #[allow(clippy::never_loop)] loop { self.run_queue.dequeue_all(|p| { let task = p.header(); #[cfg(feature = "integrated-timers")] - critical_section::with(|cs| { - task.expires_at.borrow(cs).set(u64::MAX); - }); + task.next_expiration.set(u64::MAX); if !task.state.run_dequeue() { // If task is not running, ignore it. This can happen in the following scenario: @@ -600,10 +599,10 @@ impl embassy_time_queue_driver::TimerQueue for TimerQueue { fn schedule_wake(&'static self, at: u64, waker: &core::task::Waker) { let task = waker::task_from_waker(waker); let task = task.header(); - critical_section::with(|cs| { - let expires_at = task.expires_at.borrow(cs).get(); - task.expires_at.borrow(cs).set(expires_at.min(at)); - }); + unsafe { + let expires_at = task.next_expiration.get(); + task.next_expiration.set(expires_at.min(at)); + } } } diff --git a/embassy-executor/src/raw/timer_queue.rs b/embassy-executor/src/raw/timer_queue.rs index 0be2379207..1716b8d9f4 100644 --- a/embassy-executor/src/raw/timer_queue.rs +++ b/embassy-executor/src/raw/timer_queue.rs @@ -1,11 +1,12 @@ use core::cell::Cell; use core::cmp::min; -use super::TaskRef; use critical_section::{CriticalSection, Mutex}; +use super::TaskRef; + pub(crate) struct TimerQueueItem { - pub(super) next: Mutex>>, + next: Mutex>>, } impl TimerQueueItem { @@ -17,7 +18,7 @@ impl TimerQueueItem { } pub(crate) struct TimerQueue { - pub(super) head: Mutex>>, + head: Mutex>>, } impl TimerQueue { @@ -30,7 +31,8 @@ impl TimerQueue { pub(crate) unsafe fn update(&self, p: TaskRef) { let task = p.header(); critical_section::with(|cs| { - if task.expires_at.borrow(cs).get() != u64::MAX { + task.expires_at.borrow(cs).set(task.next_expiration.get()); + if task.next_expiration.get() != u64::MAX { if task.state.timer_enqueue() { let prev = self.head.borrow(cs).replace(Some(p)); task.timer_queue_item.next.borrow(cs).set(prev);