Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrated-timers: Dequeue in alarm_callback #3579

Closed
wants to merge 13 commits into from
90 changes: 37 additions & 53 deletions embassy-executor/src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use core::pin::Pin;
use core::ptr::NonNull;
use core::task::{Context, Poll};

#[cfg(feature = "integrated-timers")]
use critical_section::Mutex;
#[cfg(feature = "integrated-timers")]
use embassy_time_driver::AlarmHandle;
#[cfg(feature = "rtos-trace")]
Expand All @@ -47,6 +49,8 @@ pub(crate) struct TaskHeader {
pub(crate) executor: SyncUnsafeCell<Option<&'static SyncExecutor>>,
poll_fn: SyncUnsafeCell<Option<unsafe fn(TaskRef)>>,

// The following fields are conceptually owned by the executor's timer queue and should not
// be accessed outside of that.
#[cfg(feature = "integrated-timers")]
pub(crate) expires_at: SyncUnsafeCell<u64>,
#[cfg(feature = "integrated-timers")]
Expand Down Expand Up @@ -162,7 +166,10 @@ impl<F: Future + 'static> TaskStorage<F> {
this.raw.state.despawn();

#[cfg(feature = "integrated-timers")]
this.raw.expires_at.set(u64::MAX);
critical_section::with(|cs| {
let executor = this.raw.executor.get().unwrap_unchecked();
executor.timer_queue.borrow(cs).notify_task_exited(p);
});
}
Poll::Pending => {}
}
Expand Down Expand Up @@ -318,7 +325,7 @@ pub(crate) struct SyncExecutor {
pender: Pender,

#[cfg(feature = "integrated-timers")]
pub(crate) timer_queue: timer_queue::TimerQueue,
pub(crate) timer_queue: Mutex<timer_queue::TimerQueue>,
#[cfg(feature = "integrated-timers")]
alarm: AlarmHandle,
}
Expand All @@ -333,7 +340,7 @@ impl SyncExecutor {
pender,

#[cfg(feature = "integrated-timers")]
timer_queue: timer_queue::TimerQueue::new(),
timer_queue: Mutex::new(timer_queue::TimerQueue::new(alarm)),
#[cfg(feature = "integrated-timers")]
alarm,
}
Expand Down Expand Up @@ -363,6 +370,11 @@ impl SyncExecutor {
#[cfg(feature = "integrated-timers")]
fn alarm_callback(ctx: *mut ()) {
let this: &Self = unsafe { &*(ctx as *const Self) };

critical_section::with(|cs| unsafe {
this.timer_queue.borrow(cs).dispatch(wake_task_no_pend);
});

this.pender.pend();
}

Expand All @@ -379,56 +391,27 @@ impl SyncExecutor {
///
/// Same as [`Executor::poll`], plus you must only call this on the thread this executor was created.
pub(crate) unsafe fn poll(&'static self) {
#[allow(clippy::never_loop)]
loop {
#[cfg(feature = "integrated-timers")]
self.timer_queue
.dequeue_expired(embassy_time_driver::now(), wake_task_no_pend);

self.run_queue.dequeue_all(|p| {
let task = p.header();

#[cfg(feature = "integrated-timers")]
task.expires_at.set(u64::MAX);

if !task.state.run_dequeue() {
// If task is not running, ignore it. This can happen in the following scenario:
// - Task gets dequeued, poll starts
// - While task is being polled, it gets woken. It gets placed in the queue.
// - Task poll finishes, returning done=true
// - RUNNING bit is cleared, but the task is already in the queue.
return;
}

#[cfg(feature = "rtos-trace")]
trace::task_exec_begin(p.as_ptr() as u32);

// Run the task
task.poll_fn.get().unwrap_unchecked()(p);
self.run_queue.dequeue_all(|p| {
let task = p.header();

if !task.state.run_dequeue() {
// If task is not running, ignore it. This can happen in the following scenario:
// - Task gets dequeued, poll starts
// - While task is being polled, it gets woken. It gets placed in the queue.
// - Task poll finishes, returning done=true
// - RUNNING bit is cleared, but the task is already in the queue.
return;
}

#[cfg(feature = "rtos-trace")]
trace::task_exec_end();
#[cfg(feature = "rtos-trace")]
trace::task_exec_begin(p.as_ptr() as u32);

// Enqueue or update into timer_queue
#[cfg(feature = "integrated-timers")]
self.timer_queue.update(p);
});
// Run the task
task.poll_fn.get().unwrap_unchecked()(p);

#[cfg(feature = "integrated-timers")]
{
// If this is already in the past, set_alarm might return false
// In that case do another poll loop iteration.
let next_expiration = self.timer_queue.next_expiration();
if embassy_time_driver::set_alarm(self.alarm, next_expiration) {
break;
}
}

#[cfg(not(feature = "integrated-timers"))]
{
break;
}
}
#[cfg(feature = "rtos-trace")]
trace::task_exec_end();
});

#[cfg(feature = "rtos-trace")]
trace::system_idle();
Expand Down Expand Up @@ -583,10 +566,11 @@ struct TimerQueue;
impl embassy_time_queue_driver::TimerQueue for TimerQueue {
fn schedule_wake(&'static self, at: u64, waker: &core::task::Waker) {
let task = waker::task_from_waker(waker);
let task = task.header();
unsafe {
let expires_at = task.expires_at.get();
task.expires_at.set(expires_at.min(at));
critical_section::with(|cs| {
let executor = task.header().executor.get().unwrap_unchecked();
executor.timer_queue.borrow(cs).schedule(task, at);
});
}
}
}
Expand Down
91 changes: 60 additions & 31 deletions embassy-executor/src/raw/timer_queue.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::cmp::min;

use super::TaskRef;
use crate::raw::util::SyncUnsafeCell;
use super::util::SyncUnsafeCell;
use super::{AlarmHandle, TaskRef};

pub(crate) struct TimerQueueItem {
next: SyncUnsafeCell<Option<TaskRef>>,
Expand All @@ -17,57 +17,86 @@ impl TimerQueueItem {

pub(crate) struct TimerQueue {
head: SyncUnsafeCell<Option<TaskRef>>,
alarm: AlarmHandle,
}

impl TimerQueue {
pub const fn new() -> Self {
pub const fn new(alarm: AlarmHandle) -> Self {
Self {
head: SyncUnsafeCell::new(None),
alarm,
}
}

pub(crate) unsafe fn update(&self, p: TaskRef) {
pub(crate) unsafe fn notify_task_exited(&self, p: TaskRef) {
let task = p.header();
if task.expires_at.get() != u64::MAX {
if task.state.timer_enqueue() {
task.timer_queue_item.next.set(self.head.get());
self.head.set(Some(p));
}
}

// Trigger removal from the timer queue.
task.expires_at.set(u64::MAX);
self.dispatch(super::wake_task);
}

pub(crate) unsafe fn next_expiration(&self) -> u64 {
let mut res = u64::MAX;
self.retain(|p| {
let task = p.header();
let expires = task.expires_at.get();
res = min(res, expires);
expires != u64::MAX
});
res
pub(crate) unsafe fn schedule(&self, p: TaskRef, at: u64) {
let task = p.header();
let update = if task.state.timer_enqueue() {
// Not in the queue, add it and update.
let prev = self.head.replace(Some(p));
task.timer_queue_item.next.set(prev);

true
} else {
// Expiration is sooner than previously set, update.
at < task.expires_at.get()
};

if update {
task.expires_at.set(at);
self.dispatch(super::wake_task);
}
}

pub(crate) unsafe fn dequeue_expired(&self, now: u64, on_task: impl Fn(TaskRef)) {
self.retain(|p| {
let task = p.header();
if task.expires_at.get() <= now {
on_task(p);
false
} else {
true
pub(crate) unsafe fn dispatch(&self, on_task: fn(TaskRef)) {
loop {
let now = embassy_time_driver::now();

let mut next_expiration = u64::MAX;

self.retain(|p| {
let task = p.header();
let expires = task.expires_at.get();

if expires <= now {
// Timer expired, process task.
on_task(p);
false
} else {
// Timer didn't yet expire, or never expires.
next_expiration = min(next_expiration, expires);
expires != u64::MAX
}
});

if self.update_alarm(next_expiration) {
break;
}
});
}
}

fn update_alarm(&self, next_alarm: u64) -> bool {
if next_alarm == u64::MAX {
true
} else {
embassy_time_driver::set_alarm(self.alarm, next_alarm)
}
}

pub(crate) unsafe fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) {
unsafe fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) {
let mut prev = &self.head;
while let Some(p) = prev.get() {
let task = p.header();
if f(p) {
// Skip to next
prev = &task.timer_queue_item.next;
} else {
// Remove it
prev.set(task.timer_queue_item.next.get());
task.state.timer_dequeue();
}
Expand Down
5 changes: 5 additions & 0 deletions embassy-executor/src/raw/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ impl<T> SyncUnsafeCell<T> {
{
*self.value.get()
}

#[cfg(feature = "integrated-timers")]
pub unsafe fn replace(&self, value: T) -> T {
core::mem::replace(&mut *self.value.get(), value)
}
}
7 changes: 5 additions & 2 deletions embassy-stm32/src/time_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,14 @@ impl RtcDriver {
regs_gp16().cnt().write(|w| w.set_cnt(cnt as u16));

// Now, recompute all alarms
for i in 0..ALARM_COUNT {
for i in 0..self.alarm_count.load(Ordering::Relaxed) as usize {
let alarm_handle = unsafe { AlarmHandle::new(i as u8) };
let alarm = self.get_alarm(cs, alarm_handle);

self.set_alarm(alarm_handle, alarm.timestamp.get());
if !self.set_alarm(alarm_handle, alarm.timestamp.get()) {
// If the alarm timestamp has passed, we need to trigger it
self.trigger_alarm(i, cs);
}
}
}

Expand Down