From 58ab058fdd7efffcb2086c5b0b8c82a3403ea206 Mon Sep 17 00:00:00 2001 From: Jari Vetoniemi Date: Mon, 24 Jun 2024 21:51:13 +0900 Subject: [PATCH] coro: take the coro api to different direction --- docs/pages/coro-context-switches.mdx | 34 +- docs/pages/coro-io.mdx | 3 +- docs/pages/coro-scheduler.mdx | 27 +- examples/coro.zig | 19 +- src/aio.zig | 12 +- src/aio/Fallback.zig | 21 +- src/aio/IoUring.zig | 52 ++- src/aio/ops.zig | 7 +- src/coro.zig | 535 +-------------------------- src/coro/Frame.zig | 180 +++++++++ src/coro/Scheduler.zig | 110 ++++++ src/coro/Sync.zig | 42 +++ src/coro/Task.zig | 87 +++++ src/coro/ThreadPool.zig | 83 +++++ src/coro/common.zig | 48 +++ src/coro/io.zig | 50 +++ 16 files changed, 701 insertions(+), 609 deletions(-) create mode 100644 src/coro/Frame.zig create mode 100644 src/coro/Scheduler.zig create mode 100644 src/coro/Sync.zig create mode 100644 src/coro/Task.zig create mode 100644 src/coro/ThreadPool.zig create mode 100644 src/coro/common.zig create mode 100644 src/coro/io.zig diff --git a/docs/pages/coro-context-switches.mdx b/docs/pages/coro-context-switches.mdx index 6f6e1a2..d588448 100644 --- a/docs/pages/coro-context-switches.mdx +++ b/docs/pages/coro-context-switches.mdx @@ -4,38 +4,20 @@ To yield running task to the caller use the following. The function takes a enum value as a argument representing the yield state of the task. +The enum value that corresponds to the integer `0` is reserved to indicate the non-yielded state. ```zig coro.yield(SomeEnum.value); ``` -To continue running the task from where it left, you need to issue the same enum value to the following function. -If you specify the third parameter as `.wait`, the current task will yield (or block if not a task) until the -target task yields to that particural state. If instead third parameter is `.no_wait` then if the task currently -isn't being yielded in the supplied state, the call is a no-op. +The current yield state of a task can be checked with the `state` method. +To wake up a task use the `wakeup` method. When a task is woken up, its yield state is reset to `0`. +Calling `wakeup` when the task isn't yielded in an application-defined yield state is an error. -```zig -coro.wakeupFromState(task, SomeEnum.value, .wait); -``` - -This is the preferred way to handle the control flow between tasks. - -## Canceling IO - -While it's possible to cancel IO by using the `aio.Cancel` operations. It is also possible to cancel -all IO operations currently blocking a task by doing the following. -If the task currently isn't being yielded by IO then the call is no-op. +Example of checking the current yield state and then waking up the task: ```zig -coro.wakeupFromIo(task); +switch (task.state(SomeEnum)) { + .value => task.wakeup(), + else => {}, +} ``` - -## Unpaired wakeup - -Sometimes it's useful to be able to wakeup the task from any yielding state. - -```zig -coro.wakeup(task); -``` - -In this case the task will wake up no matter what its yielding state is currently. diff --git a/docs/pages/coro-io.mdx b/docs/pages/coro-io.mdx index 7678c80..7816b2d 100644 --- a/docs/pages/coro-io.mdx +++ b/docs/pages/coro-io.mdx @@ -25,5 +25,4 @@ outside a task then the call would be equal to calling the equal function from t Use `aio.Cancel` operation to cancel the currently running operations in a task. The `out_error` of such operation will then be set as `error.OperationCanceled`. -Alternatively it's possible to call `scheduler.wakeup(task);` or `scheduler.wakeupFromIo(task)` -which also cancels all currently running IO on that task.
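+
+A minimal sketch of such a cancellation, assuming `id` was captured through the `out_id`
+field of the operation that should be canceled (the names here are illustrative):
+
+```zig
+// `id` was filled in via `out_id` when the long-running operation was queued,
+// e.g. aio.Timeout{ .ns = ..., .out_id = &id }
+try coro.io.single(aio.Cancel{ .id = id });
+```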
+Alternatively it's possible to call `task.complete(.cancel);` to actively cancel a task and collect its partial result. diff --git a/docs/pages/coro-scheduler.mdx b/docs/pages/coro-scheduler.mdx index 79edcb1..736a8e1 100644 --- a/docs/pages/coro-scheduler.mdx +++ b/docs/pages/coro-scheduler.mdx @@ -31,24 +31,21 @@ yields or performs a IO operation using one of the `coro.io` namespace functions var task = try scheduler.spawn(entrypoint, .{ 1, "args" }, .{}); ``` -### Reaping tasks +### Collecting the result of a task -Following removes a task, freeing its memory and canceling all running IO operations for that task. -The reap may be delayed in case the task is currently doing IO, the IO operations will be actively canceled. +Use the following to collect the result of a task. +After collecting the result, the task handle is no longer valid. ```zig -scheduler.reap(task); +const res = task.complete(.wait); ``` -Alternatively reap all the tasks using the following. -The reap may be delayed in case the tasks are currently doing IO, the IO operations will be actively canceled. +To cancel the task and collect its partial result. ```zig -scheduler.reapAll(); +const res = task.complete(.cancel); ``` -Call to `deinit` also reaps all tasks. - ### Running The scheduler can process tasks and IO one step a time with the tick method. @@ -56,14 +53,18 @@ By running tick the scheduler will reap tasks that returned (dead tasks) and con tasks in case they completed their IO operations. ```zig -// if there are pending IO operations, blocks until at least one completes +// If there are pending IO operations, blocks until at least one completes try scheduler.tick(.blocking); -// returns immediately regardless of the current IO state +// Returns immediately regardless of the current IO state try scheduler.tick(.nonblocking); ``` -To run the scheduler until all tasks have returned aka died, then use the following. +To run the scheduler until all tasks have completed, use the following. +The `mode` option lets you decide whether to wait for all tasks to finish, or to actively cancel them.
```zig -try scheduler.run(); +// Wait the tasks forever +try scheduler.run(.wait); +// Actively cancel the tasks +try scheduler.run(.cancel); ``` diff --git a/examples/coro.zig b/examples/coro.zig index ed06a98..5b1162f 100644 --- a/examples/coro.zig +++ b/examples/coro.zig @@ -15,11 +15,7 @@ pub const std_options: std.Options = .{ .log_level = .debug, }; -const Yield = enum { - server_ready, -}; - -fn server(client_task: coro.Task) !void { +fn server(lock: *coro.ResetEvent) !void { var socket: std.posix.socket_t = undefined; try coro.io.single(aio.Socket{ .domain = std.posix.AF.INET, @@ -36,7 +32,7 @@ fn server(client_task: coro.Task) !void { try std.posix.bind(socket, &address.any, address.getOsSockLen()); try std.posix.listen(socket, 128); - coro.wakeupFromState(client_task, Yield.server_ready, .wait); + lock.set(); var client_sock: std.posix.socket_t = undefined; try coro.io.single(aio.Accept{ .socket = socket, .out_socket = &client_sock }); @@ -58,7 +54,7 @@ fn server(client_task: coro.Task) !void { }); } -fn client() !void { +fn client(lock: *coro.ResetEvent) !void { var socket: std.posix.socket_t = undefined; try coro.io.single(aio.Socket{ .domain = std.posix.AF.INET, @@ -67,7 +63,7 @@ fn client() !void { .out_socket = &socket, }); - coro.yield(Yield.server_ready); + lock.wait(); const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 1327); try coro.io.single(aio.Connect{ @@ -99,7 +95,8 @@ pub fn main() !void { defer _ = gpa.deinit(); var scheduler = try coro.Scheduler.init(gpa.allocator(), .{}); defer scheduler.deinit(); - const client_task = try scheduler.spawn(client, .{}, .{}); - _ = try scheduler.spawn(server, .{client_task}, .{}); - try scheduler.run(); + var lock: coro.ResetEvent = .{}; + _ = try scheduler.spawn(client, .{&lock}, .{}); + _ = try scheduler.spawn(server, .{&lock}, .{}); + try scheduler.run(.wait); } diff --git a/src/aio.zig b/src/aio.zig index 30c6eb0..9928eed 100644 --- a/src/aio.zig +++ b/src/aio.zig @@ -49,7 +49,7 @@ pub const CompletionResult = struct { /// Queue operations dynamically and complete them on demand pub const Dynamic = struct { pub const Uop = ops.Operation.Union; - pub const Callback = *const fn (uop: Uop) void; + pub const Callback = *const fn (uop: Uop, id: Id, failed: bool) void; io: IO, /// Used by coro implementation callback: ?Callback = null, @@ -135,11 +135,8 @@ pub inline fn multi(operations: anytype) (Error || error{SomeOperationFailed})!v pub inline fn single(operation: anytype) (Error || @TypeOf(operation).Error)!void { var op: @TypeOf(operation) = operation; var err: @TypeOf(operation).Error = error.Success; - if (@hasField(@TypeOf(op), "out_error")) { - op.out_error = &err; - } - _ = try complete(.{op}); - if (err != error.Success) return err; + if (@hasField(@TypeOf(op), "out_error")) op.out_error = &err; + if (try complete(.{op}) > 0) return err; } /// Checks if the current backend supports the operations @@ -231,9 +228,10 @@ test "Nop" { defer dynamic.deinit(std.testing.allocator); try dynamic.queue(Nop{ .domain = @enumFromInt(255), .ident = 69, .userdata = 42 }); const Lel = struct { - fn cb(uop: Dynamic.Uop) void { + fn cb(uop: Dynamic.Uop, _: Id, failed: bool) void { switch (uop) { .nop => |*op| { + std.debug.assert(!failed); std.debug.assert(255 == @intFromEnum(op.domain)); std.debug.assert(69 == op.ident); std.debug.assert(42 == op.userdata); diff --git a/src/aio/Fallback.zig b/src/aio/Fallback.zig index a266273..223eed5 100644 --- a/src/aio/Fallback.zig +++ b/src/aio/Fallback.zig @@ -100,12 +100,8 @@ pub fn 
deinit(self: *@This(), allocator: std.mem.Allocator) void { } fn initOp(op: anytype, id: u16) void { - if (comptime @hasField(@TypeOf(op.*), "out_id")) { - if (op.out_id) |p_id| p_id.* = @enumFromInt(id); - } - if (comptime @hasField(@TypeOf(op.*), "out_error")) { - if (op.out_error) |out_error| out_error.* = error.Success; - } + if (op.out_id) |p_id| p_id.* = @enumFromInt(id); + if (op.out_error) |out_error| out_error.* = error.Success; } fn addOp(self: *@This(), uop: Operation.Union, linked_to: ?u16, readiness: posix.Readiness) !u16 { @@ -334,10 +330,7 @@ fn submit(self: *@This()) !bool { } fn completition(op: anytype, self: *@This(), res: Result) void { - if (comptime @hasField(@TypeOf(op.*), "out_error")) { - if (op.out_error) |err| err.* = @errorCast(res.failure); - } - + if (op.out_error) |err| err.* = @errorCast(res.failure); if (op.link != .unlinked and self.next[res.id] != res.id) { if (self.ops.nodes[self.next[res.id]].used == .link_timeout) { switch (op.link) { @@ -369,14 +362,18 @@ fn handleFinished(self: *@This(), cb: ?aio.Dynamic.Callback) aio.CompletionResul } else { debug("complete: {}: {} [OK]", .{ res.id, std.meta.activeTag(self.ops.nodes[res.id].used) }); } - defer self.removeOp(res.id); + if (self.ops.nodes[res.id].used == .link_timeout and res.failure == error.OperationCanceled) { // special case } else { num_errors += @intFromBool(res.failure != error.Success); } + uopUnwrapCall(&self.ops.nodes[res.id].used, completition, .{ self, res }); - if (cb) |f| f(self.ops.nodes[res.id].used); + + const uop = self.ops.nodes[res.id].used; + self.removeOp(res.id); + if (cb) |f| f(uop, @enumFromInt(res.id), res.failure != error.Success); } return .{ .num_completed = self.finished_copy.len, .num_errors = num_errors }; diff --git a/src/aio/IoUring.zig b/src/aio/IoUring.zig index 2694730..c8da60a 100644 --- a/src/aio/IoUring.zig +++ b/src/aio/IoUring.zig @@ -98,26 +98,32 @@ pub const NOP = std.math.maxInt(u64); pub fn complete(self: *@This(), mode: aio.Dynamic.CompletionMode, cb: ?aio.Dynamic.Callback) aio.Error!aio.CompletionResult { if (self.ops.empty()) return .{}; + if (mode == .nonblocking) { _ = self.io.nop(NOP) catch |err| return switch (err) { error.SubmissionQueueFull => .{}, }; } + _ = try uring_submit(&self.io); + var result: aio.CompletionResult = .{}; var cqes: [aio.options.io_uring_cqe_sz]std.os.linux.io_uring_cqe = undefined; const n = try uring_copy_cqes(&self.io, &cqes, 1); for (cqes[0..n]) |*cqe| { if (cqe.user_data == NOP) continue; - defer self.ops.remove(@intCast(cqe.user_data)); - const uop = self.ops.get(@intCast(cqe.user_data)); - switch (uop.*) { + const uop = self.ops.get(@intCast(cqe.user_data)).*; + var failed: bool = false; + switch (uop) { inline else => |*op| uring_handle_completion(op, cqe) catch { result.num_errors += 1; + failed = true; }, } - if (cb) |f| f(uop.*); + self.ops.remove(@intCast(cqe.user_data)); + if (cb) |f| f(uop, @enumFromInt(cqe.user_data), failed); } + result.num_completed = n - @intFromBool(mode == .nonblocking); return result; } @@ -271,12 +277,8 @@ inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) ai .soft => sqe.flags |= std.os.linux.IOSQE_IO_LINK, .hard => sqe.flags |= std.os.linux.IOSQE_IO_HARDLINK, } - if (comptime @hasField(@TypeOf(op.*), "out_id")) { - if (op.out_id) |id| id.* = @enumFromInt(user_data); - } - if (comptime @hasField(@TypeOf(op.*), "out_error")) { - if (op.out_error) |out_error| out_error.* = error.Success; - } + if (op.out_id) |id| id.* = @enumFromInt(user_data); + if 
(op.out_error) |out_error| out_error.* = error.Success; } inline fn uring_submit(io: *std.os.linux.IoUring) aio.Error!u16 { @@ -320,7 +322,10 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) const err = cqe.err(); if (err != .SUCCESS) { const res: @TypeOf(op.*).Error = switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) { - .nop => unreachable, + .nop => switch (err) { + .CANCELED => error.OperationCanceled, + else => std.posix.unexpectedErrno(err), + }, .fsync => switch (err) { .SUCCESS, .INTR, .INVAL, .FAULT, .AGAIN, .ROFS => unreachable, .BADF => unreachable, // not a file @@ -454,25 +459,32 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) .ILSEQ => error.InvalidUtf8, else => std.posix.unexpectedErrno(err), }, - .close_file, .close_dir, .close_socket => unreachable, - .notify_event_source, .wait_event_source, .close_event_source => unreachable, + .close_file, .close_dir, .close_socket => switch (err) { + .CANCELED => error.OperationCanceled, + else => std.posix.unexpectedErrno(err), + }, + .notify_event_source, .wait_event_source, .close_event_source => switch (err) { + .CANCELED => error.OperationCanceled, + else => std.posix.unexpectedErrno(err), + }, .timeout => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, .TIME => error.Success, .CANCELED => error.OperationCanceled, - else => unreachable, + else => std.posix.unexpectedErrno(err), }, .link_timeout => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, .TIME => error.Expired, .ALREADY, .CANCELED => error.OperationCanceled, - else => unreachable, + else => std.posix.unexpectedErrno(err), }, .cancel => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, .ALREADY => error.InProgress, .NOENT => error.NotFound, - else => unreachable, + .CANCELED => error.OperationCanceled, + else => std.posix.unexpectedErrno(err), }, .rename_at => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, @@ -519,6 +531,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) }, .mkdir_at => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, .ACCES => error.AccessDenied, .BADF => unreachable, .PERM => error.AccessDenied, @@ -539,6 +552,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) }, .symlink_at => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN, .FAULT => unreachable, + .CANCELED => error.OperationCanceled, .ACCES => error.AccessDenied, .PERM => error.AccessDenied, .DQUOT => error.DiskQuota, @@ -555,11 +569,13 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) }, .child_exit => switch (err) { .SUCCESS, .INTR, .AGAIN, .FAULT, .INVAL => unreachable, + .CANCELED => error.OperationCanceled, .CHILD => error.NotFound, else => std.posix.unexpectedErrno(err), }, .socket => switch (err) { .SUCCESS, .INTR, .AGAIN, .FAULT => unreachable, + .CANCELED => error.OperationCanceled, .ACCES => error.PermissionDenied, .AFNOSUPPORT => error.AddressFamilyNotSupported, .INVAL => error.ProtocolFamilyNotAvailable, @@ -573,9 +589,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) }, }; - if (comptime @hasField(@TypeOf(op.*), "out_error")) { - if (op.out_error) |out_error| out_error.* = res; - } + if (op.out_error) |out_error| out_error.* = res; if (res != error.Success) { if ((comptime Operation.tagFromPayloadType(@TypeOf(op.*)) == .link_timeout) and res == 
error.OperationCanceled) { diff --git a/src/aio/ops.zig b/src/aio/ops.zig index 06d1857..391daf5 100644 --- a/src/aio/ops.zig +++ b/src/aio/ops.zig @@ -20,13 +20,15 @@ const SharedError = error{ /// Can be used to wakeup the backend, custom notifications, etc... pub const Nop = struct { - pub const Error = error{Success}; + pub const Error = SharedError; pub const Domain = enum(u8) { coro, // reserved, please don't use _, }; domain: Domain, ident: usize, + out_id: ?*Id = null, + out_error: ?*Error = null, link: Link = .unlinked, userdata: usize = 0, }; @@ -175,8 +177,9 @@ pub const LinkTimeout = struct { /// Cancel a operation pub const Cancel = struct { - pub const Error = error{ Success, InProgress, NotFound }; + pub const Error = error{ InProgress, NotFound } || SharedError; id: Id, + out_id: ?*Id = null, out_error: ?*Error = null, link: Link = .unlinked, userdata: usize = 0, diff --git a/src/coro.zig b/src/coro.zig index bf3c88f..9178efc 100644 --- a/src/coro.zig +++ b/src/coro.zig @@ -17,97 +17,32 @@ pub const Options = struct { io_queue_entries: u16 = 4096, /// Default stack size for coroutines stack_size: usize = 1.049e+6, // 1 MiB - /// Default handler for errors for !void coroutines - error_handler: fn (err: anyerror) void = defaultErrorHandler, }; -fn defaultErrorHandler(err: anyerror) void { - std.debug.print("error: {s}\n", .{@errorName(err)}); - if (@errorReturnTrace()) |trace| { - std.debug.dumpStackTrace(trace.*); - } -} - -fn debug(comptime fmt: []const u8, args: anytype) void { - if (@import("builtin").is_test) { - std.debug.print("coro: " ++ fmt ++ "\n", args); - } else { - if (comptime !options.debug) return; - const log = std.log.scoped(.coro); - log.debug(fmt, args); - } -} - -pub const io = struct { - fn privateComplete(operations: anytype, yield_state: YieldState) aio.Error!u16 { - if (Fiber.current()) |fiber| { - var task: *TaskState = @ptrFromInt(fiber.getUserDataPtr().*); - - const State = struct { - old_userdata: usize, - old_err: ?*anyerror, - old_id: ?*aio.Id, - id: aio.Id, - err: anyerror, - }; +pub const Scheduler = @import("coro/Scheduler.zig"); +pub const ThreadPool = @import("coro/ThreadPool.zig"); +pub const Task = @import("coro/Task.zig"); - var state: [operations.len]State = undefined; - var work = struct { ops: @TypeOf(operations) }{ .ops = operations }; - inline for (&work.ops, &state) |*op, *s| { - s.err = error.Success; - if (@hasDecl(@TypeOf(op.*), "out_id")) { - s.old_id = op.out_id; - op.out_id = &s.id; - } else { - s.old_id = null; - } - if (@hasDecl(@TypeOf(op.*), "out_error")) { - s.old_err = op.out_error; - op.out_error = @ptrCast(&s.err); - } else { - s.old_err = null; - } - s.old_userdata = op.userdata; - op.userdata = @intFromPtr(task); - } +const Sync = @import("coro/Sync.zig"); +pub const Semaphore = Sync.Semaphore; +pub const ResetEvent = Sync.ResetEvent; - try task.scheduler.io.queue(work.ops); - task.io_counter = operations.len; - privateYield(yield_state); - - if (task.io_counter > 0) { - // woken up for io cancelation - // TODO: don't cancel already completed - // Currently this code is broken - if (false) { - var cancels: [operations.len]aio.Cancel = undefined; - inline for (&cancels, &state) |*op, *s| op.* = .{ .id = s.id, .userdata = @intFromPtr(task) }; - try task.scheduler.io.queue(cancels); - task.io_counter += @intCast(operations.len - cancels.len); - privateYield(.io_cancel); - } - unreachable; - } - - var num_errors: u16 = 0; - inline for (&work.ops, &state) |*op, *s| { - num_errors += @intFromBool(s.err != error.Success); 
- if (s.old_err) |p| p.* = s.err; - if (s.old_id) |p| p.* = s.id; - op.userdata = s.old_userdata; - } - return num_errors; - } else { - return aio.complete(operations); - } - } +// Aliases +pub const current = Task.current; +pub const yield = Task.yield; +/// Nonblocking IO (on a task) +pub const io = struct { /// Completes a list of operations immediately, blocks the coroutine until complete /// The IO operations can be cancelled by calling `wakeup` /// For error handling you must check the `out_error` field in the operation /// Returns the number of errors occured, 0 if there were no errors pub inline fn complete(operations: anytype) aio.Error!u16 { - return privateComplete(operations, .io); + const do = @import("coro/io.zig").do; + return do(operations, .io) catch |err| switch (err) { + error.OperationCanceled => operations.len, + else => |e| e, + }; } /// Completes a list of operations immediately, blocks until complete @@ -122,441 +57,7 @@ pub const io = struct { pub inline fn single(operation: anytype) (aio.Error || @TypeOf(operation).Error)!void { var op: @TypeOf(operation) = operation; var err: @TypeOf(operation).Error = error.Success; - if (@hasField(@TypeOf(op), "out_error")) { - op.out_error = &err; - } - _ = try complete(.{op}); - if (err != error.Success) return err; - } -}; - -/// Is the code currently running inside a task? -pub inline fn isTask() bool { - return Fiber.current() != null; -} - -/// Yields current task, can only be called from inside a task -pub inline fn yield(state: anytype) void { - privateYield(@enumFromInt(std.meta.fields(YieldState).len + @intFromEnum(state))); -} - -pub const WakeupMode = enum { no_wait, wait }; - -/// Wakeups a task from a yielded state, no-op if `state` does not match the current yielding state -/// `mode` lets you select whether to `wait` for the state to change before trying to wake up or not -pub inline fn wakeupFromState(task: Task, state: anytype, mode: WakeupMode) void { - const node: *Scheduler.Tasks.Node = @ptrCast(task); - node.data.cast().wakeup(@enumFromInt(std.meta.fields(YieldState).len + @intFromEnum(state)), mode); -} - -/// Wakeups a task from IO by canceling the current IO operations for that task -pub inline fn wakeupFromIo(task: Task) void { - const node: *Scheduler.Tasks.Node = @ptrCast(task); - node.data.cast().wakeup(.io, .no_wait); -} - -/// Wakeups a task regardless of the current yielding state -pub inline fn wakeup(task: Task) void { - const node: *Scheduler.Tasks.Node = @ptrCast(task); - const state = node.data.cast(); - // do not wakeup from io_cancel state as that can potentially lead to memory corruption - if (state.yield_state == .io_cancel) return; - // ditto for io_waiting_thread - if (state.yield_state == .io_waiting_thread) return; - state.wakeup(state.yield_state, .no_wait); -} - -const YieldState = enum(u8) { - not_yielding, // task is running - waiting_for_yield, // waiting for another task to change yield state - io, // waiting for io - io_waiting_thread, // cannot be canceled - io_cancel, // cannot be canceled - _, // fields after are reserved for custom use - - pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - if (@intFromEnum(self) < std.meta.fields(@This()).len) { - try writer.writeAll(@tagName(self)); - } else { - try writer.print("custom {}", .{@intFromEnum(self)}); - } - } -}; - -fn privateYield(state: YieldState) void { - if (Fiber.current()) |fiber| { - var task: *TaskState = @ptrFromInt(fiber.getUserDataPtr().*); - 
std.debug.assert(task.yield_state == .not_yielding); - task.yield_state = state; - debug("yielding: {}", .{task}); - if (@intFromEnum(state) >= std.meta.fields(YieldState).len and (task.waiters.first != null or task.scheduler_is_waiting)) { - task.scheduler.io.queue(aio.Nop{ - .domain = .coro, - .ident = @intFromEnum(YieldState.waiting_for_yield), - .userdata = @intFromPtr(task), - }) catch @panic("cannot yield, the submission queue is full"); - } - Fiber.yield(); - } else { - unreachable; // yield can only be used from a task - } -} - -const TaskState = struct { - const Waiters = std.SinglyLinkedList(Link(TaskState, "waiter_link", .single)); - fiber: *Fiber, - stack: ?Fiber.Stack = null, - marked_for_reap: bool = false, - yield_state: YieldState = .not_yielding, - scheduler: *Scheduler, - io_counter: u16 = 0, - scheduler_is_waiting: bool = false, - waiters: Waiters = .{}, - waiter_link: Waiters.Node = .{ .data = .{} }, - link: Scheduler.Tasks.Node = .{ .data = .{} }, - - pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - if (self.io_counter > 0) { - try writer.print("{x}: {}, {} ops left", .{ @intFromPtr(self.fiber), self.yield_state, self.io_counter }); - } else { - try writer.print("{x}: {}", .{ @intFromPtr(self.fiber), self.yield_state }); - } - } - - fn deinit(self: *@This(), allocator: std.mem.Allocator) void { - // we can only safely deinit the task when it is not doing IO - // otherwise for example io_uring might write tok a invalid memory address - debug("deinit: {}", .{self}); - std.debug.assert(self.isReapable()); - if (self.stack) |stack| allocator.free(stack); - // stack is now gone, doing anything after this with TaskState is ub - } - - fn wakeup(self: *TaskState, state: YieldState, mode: WakeupMode) void { - if (mode == .wait) { - debug("waiting to wake up: {} when it yields {}", .{ self, state }); - while (self.yield_state != state and !self.marked_for_reap) { - if (Fiber.current()) |fiber| { - var task: *TaskState = @ptrFromInt(fiber.getUserDataPtr().*); - self.waiters.prepend(&task.waiter_link); - privateYield(.waiting_for_yield); - } else { - std.debug.assert(!self.scheduler_is_waiting); // deadlock - self.scheduler_is_waiting = true; - self.scheduler.tick(.blocking) catch @panic("unrecovable"); - self.scheduler_is_waiting = false; - } - } - } - if (self.marked_for_reap) return; - if (self.yield_state != state) return; - debug("waking up: {}", .{self}); - self.yield_state = .not_yielding; - self.fiber.switchTo(); - } - - inline fn isReapable(self: *TaskState) bool { - return switch (self.yield_state) { - .io, .io_waiting_thread, .io_cancel => false, - else => true, - }; - } -}; - -pub const Task = *Scheduler.Tasks.Node; - -/// Runtime for asynchronous IO tasks -pub const Scheduler = struct { - const Tasks = std.DoublyLinkedList(Link(TaskState, "link", .double)); - allocator: std.mem.Allocator, - io: aio.Dynamic, - tasks: Tasks = .{}, - tasks_pending_reap: Tasks = .{}, - - pub const InitOptions = struct { - /// This is a hint, the implementation makes the final call - io_queue_entries: u16 = options.io_queue_entries, - }; - - pub fn init(allocator: std.mem.Allocator, opts: InitOptions) aio.Error!@This() { - var work = try aio.Dynamic.init(allocator, opts.io_queue_entries); - work.callback = ioComplete; - return .{ .allocator = allocator, .io = work }; - } - - pub fn reapAll(self: *@This()) void { - while (self.tasks.pop()) |node| { - if (!node.data.cast().marked_for_reap) { - node.data.cast().marked_for_reap = true; - 
self.tasks_pending_reap.append(node); - } - self.privateReap(node); - } - } - - fn privateReap(self: *@This(), node: *Tasks.Node) void { - var task = node.data.cast(); - if (!task.isReapable()) { - if (task.yield_state == .io) { - debug("task is pending on io, reaping later: {}", .{task}); - // TODO: io cancellations are broken - // task.wakeup(.io, .no_wait); // cancel io - } - if (!task.marked_for_reap) { - task.marked_for_reap = true; - self.tasks_pending_reap.append(node); - } - return; // still pending - } - if (task.marked_for_reap) self.tasks_pending_reap.remove(node); - task.deinit(self.allocator); - } - - pub fn reap(self: *@This(), node: Task) void { - self.tasks.remove(@ptrCast(node)); - self.privateReap(@ptrCast(node)); - } - - pub fn deinit(self: *@This()) void { - // destroy io backend first to make sure we can destroy the tasks safely - self.io.deinit(self.allocator); - inline for (&.{ &self.tasks, &self.tasks_pending_reap }) |list| { - while (list.pop()) |node| { - var task = node.data.cast(); - // modify the yield state to avoid state consistency assert in deinit - // it's okay to deinit now since the io backend is dead - // if using thread pool, it is error to deinit scheduler before thread - // pool is deinited, so we don't have to wait for those tasks either - task.yield_state = .not_yielding; - task.deinit(self.allocator); - } - } - self.* = undefined; - } - - inline fn entrypoint(self: *@This(), stack: Fiber.Stack, comptime func: anytype, args: anytype) void { - var state: TaskState = .{ - .fiber = Fiber.current().?, - .scheduler = self, - .stack = stack, - }; - state.fiber.getUserDataPtr().* = @intFromPtr(&state); - self.tasks.append(&state.link); - debug("spawned: {}", .{state}); - - if (@typeInfo(@typeInfo(@TypeOf(func)).Fn.return_type.?) 
== .ErrorUnion) { - @call(.auto, func, args) catch |err| options.error_handler(err); - } else { - @call(.auto, func, args); - } - - debug("finished: {}", .{state}); - - if (!state.marked_for_reap) { - self.tasks.remove(&state.link); - if (state.stack) |_| { - // stack is managed, it needs to be cleaned outside - state.marked_for_reap = true; - self.tasks_pending_reap.append(&state.link); - } - } - } - - pub const SpawnError = error{OutOfMemory} || Fiber.Error; - - pub const SpawnOptions = struct { - stack: union(enum) { - unmanaged: Fiber.Stack, - managed: usize, - } = .{ .managed = options.stack_size }, - }; - - /// Spawns a new task, the task may do local IO operations which will not block the whole process using the `io` namespace functions - pub fn spawn(self: *@This(), comptime func: anytype, args: anytype, opts: SpawnOptions) SpawnError!Task { - const stack = switch (opts.stack) { - .unmanaged => |buf| buf, - .managed => |sz| try self.allocator.alignedAlloc(u8, Fiber.stack_alignment, sz), - }; - errdefer if (opts.stack == .managed) self.allocator.free(stack); - var fiber = try Fiber.init(stack, 0, entrypoint, .{ self, stack, func, args }); - fiber.switchTo(); - return self.tasks.last.?; - } - - fn ioComplete(uop: aio.Dynamic.Uop) void { - switch (uop) { - inline else => |*op| { - std.debug.assert(op.userdata != 0); - var task: *TaskState = @ptrFromInt(op.userdata); - if (@TypeOf(op.*) == aio.Nop) { - if (op.domain != .coro) return; - if (op.ident == @intFromEnum(YieldState.waiting_for_yield)) { - while (task.waiters.popFirst()) |waiter| { - waiter.data.cast().wakeup(.waiting_for_yield, .no_wait); - } - } - } else { - task.io_counter -= 1; - if (task.io_counter == 0) { - task.wakeup(task.yield_state, .no_wait); - } - } - }, - } - } - - /// Processes pending IO and reaps dead tasks - pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.Error!void { - _ = try self.io.complete(mode); - var next = self.tasks_pending_reap.first; - while (next) |node| { - next = node.next; - self.privateReap(node); - } - } - - /// Run until all tasks are dead - pub fn run(self: *@This()) aio.Error!void { - while (self.tasks.len > 0 or self.tasks_pending_reap.len > 0) { - try self.tick(.blocking); - } - } -}; - -test "Wakeup" { - const Yield = enum { - wakeup_plz, - }; - const Test = struct { - fn task() !void { - try io.single(aio.Timeout{ .ns = 1 }); - yield(Yield.wakeup_plz); - } - }; - var scheduler = try Scheduler.init(std.testing.allocator, .{}); - defer scheduler.deinit(); - const task = try scheduler.spawn(Test.task, .{}, .{}); - wakeupFromState(task, Yield.wakeup_plz, .wait); - try scheduler.run(); -} - -/// Synchronizes blocking tasks on a thread pool -pub const ThreadPool = struct { - pool: std.Thread.Pool = undefined, - source: aio.EventSource = undefined, - num_tasks: std.atomic.Value(u32) = std.atomic.Value(u32).init(0), - - /// Spin up the pool, `allocator` is used to allocate the tasks - /// If `num_jobs` is zero, the thread count for the current CPU is used - pub fn start(self: *@This(), allocator: std.mem.Allocator, num_jobs: u32) !void { - self.* = .{ .pool = .{ .allocator = undefined, .threads = undefined }, .source = try aio.EventSource.init() }; - errdefer self.source.deinit(); - try self.pool.init(.{ .allocator = allocator, .n_jobs = if (num_jobs == 0) null else num_jobs }); - } - - pub fn deinit(self: *@This()) void { - self.pool.deinit(); - self.source.deinit(); - self.* = undefined; - } - - inline fn entrypoint(self: *@This(), completed: *bool, comptime func: anytype, 
ret: anytype, args: anytype) void { - ret.* = @call(.auto, func, args); - completed.* = true; - const n = self.num_tasks.load(.acquire); - for (0..n) |_| self.source.notify(); - } - - const Error = error{SomeOperationFailed} || aio.Error || aio.EventSource.Error || aio.WaitEventSource.Error; - - fn ReturnType(comptime Func: type) type { - const base = @typeInfo(Func).Fn.return_type.?; - return switch (@typeInfo(base)) { - .ErrorUnion => |eu| (Error || eu.error_set)!eu.payload, - else => Error!base, - }; - } - - /// Yield until `func` finishes on another thread - pub fn yieldForCompletition(self: *@This(), func: anytype, args: anytype) ReturnType(@TypeOf(func)) { - var completed: bool = false; - var ret: @typeInfo(@TypeOf(func)).Fn.return_type.? = undefined; - _ = self.num_tasks.fetchAdd(1, .monotonic); - defer _ = self.num_tasks.fetchSub(1, .release); - try self.pool.spawn(entrypoint, .{ self, &completed, func, &ret, args }); - var wait_err: aio.WaitEventSource.Error = error.Success; - while (!completed) { - if (try io.privateComplete(.{ - aio.WaitEventSource{ .source = &self.source, .link = .soft, .out_error = &wait_err }, - }, .io_waiting_thread) > 0) { - // it's possible to end up here if aio implementation ran out of resources - // in case of io_uring the application managed to fill up the submission queue - // normally this should not happen, but as to not crash the program do a blocking wait - self.source.wait(); - } - } - return ret; + if (@hasField(@TypeOf(op), "out_error")) op.out_error = &err; + if (try complete(.{op}) > 0) return err; } }; - -test "ThreadPool" { - const Test = struct { - const Yield = enum { - task2_free, - }; - - fn task1(task1_done: *bool) void { - yield(Yield.task2_free); - task1_done.* = true; - } - - fn blocking() u32 { - std.time.sleep(1 * std.time.ns_per_s); - return 69; - } - - fn task2(t1: Task, pool: *ThreadPool, task1_done: *bool) !void { - const ret = try pool.yieldForCompletition(blocking, .{}); - try std.testing.expectEqual(69, ret); - try std.testing.expectEqual(task1_done.*, false); - wakeupFromState(t1, Yield.task2_free, .no_wait); - try std.testing.expectEqual(task1_done.*, true); - } - - fn task3(pool: *ThreadPool) !void { - const ret = try pool.yieldForCompletition(blocking, .{}); - try std.testing.expectEqual(69, ret); - } - }; - var scheduler = try Scheduler.init(std.testing.allocator, .{}); - defer scheduler.deinit(); - var pool: ThreadPool = .{}; - try pool.start(std.testing.allocator, 0); - defer pool.deinit(); - var task1_done: bool = false; - const task1 = try scheduler.spawn(Test.task1, .{&task1_done}, .{}); - _ = try scheduler.spawn(Test.task2, .{ task1, &pool, &task1_done }, .{}); - for (0..10) |_| _ = try scheduler.spawn(Test.task3, .{&pool}, .{}); - try scheduler.run(); -} - -fn Link(comptime T: type, comptime field: []const u8, comptime container: enum { single, double }) type { - return struct { - pub fn format(self: *@This(), comptime fmt: []const u8, opts: std.fmt.FormatOptions, writer: anytype) !void { - return self.cast().format(self, fmt, opts, writer); - } - - inline fn cast(self: *@This()) *T { - switch (container) { - .single => { - const node: *std.SinglyLinkedList(@This()).Node = @alignCast(@fieldParentPtr("data", self)); - return @fieldParentPtr(field, node); - }, - .double => { - const node: *std.DoublyLinkedList(@This()).Node = @alignCast(@fieldParentPtr("data", self)); - return @fieldParentPtr(field, node); - }, - } - } - }; -} diff --git a/src/coro/Frame.zig b/src/coro/Frame.zig new file mode 100644 index 
0000000..f9fdbed --- /dev/null +++ b/src/coro/Frame.zig @@ -0,0 +1,180 @@ +const std = @import("std"); +const aio = @import("aio"); +const io = @import("io.zig"); +const Fiber = @import("zefi.zig"); +const Scheduler = @import("Scheduler.zig"); +const common = @import("common.zig"); + +pub const List = std.DoublyLinkedList(common.Link(@This(), "link", .double)); +pub const stack_alignment = Fiber.stack_alignment; +pub const Stack = Fiber.Stack; + +pub const Status = enum(u8) { + active, // frame is running + io, // waiting for io + io_cancel, // cannot be canceled + completed, // the frame is complete and has to be collected + semaphore, // waiting on a semaphore + reset_event, // waiting on a reset_event + yield, // yielded by a user code + + pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + try writer.writeAll(@tagName(self)); + } +}; + +pub const WaitList = std.SinglyLinkedList(common.Link(@This(), "wait_link", .single)); + +fiber: *Fiber, +stack: ?Fiber.Stack = null, +scheduler: *Scheduler, +canceled: bool = false, +status: Status = .active, +result: *anyopaque, +link: List.Node = .{ .data = .{} }, +waiters: WaitList = .{}, +wait_link: WaitList.Node = .{ .data = .{} }, +yield_state: u8 = 0, + +pub fn current() ?*@This() { + if (Fiber.current()) |fiber| { + return @ptrFromInt(fiber.getUserDataPtr().*); + } else { + return null; + } +} + +pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + try writer.print("{x}!{}", .{ @intFromPtr(self.fiber), self.status }); +} + +pub const Error = error{OutOfMemory} || Fiber.Error; + +inline fn entrypoint( + scheduler: *Scheduler, + stack: Fiber.Stack, + Result: type, + out_frame: **@This(), + comptime func: anytype, + args: anytype, +) void { + var res: Result = undefined; + var frame: @This() = .{ + .fiber = Fiber.current().?, + .stack = stack, + .scheduler = scheduler, + .result = &res, + }; + frame.fiber.getUserDataPtr().* = @intFromPtr(&frame); + scheduler.frames.append(&frame.link); + out_frame.* = &frame; + + debug("spawned: {}", .{frame}); + res = @call(.always_inline, func, args); + scheduler.num_complete += 1; + + // keep the stack alive, until the task is collected + yield(.completed); +} + +pub fn init(scheduler: *Scheduler, stack: Stack, Result: type, comptime func: anytype, args: anytype) Error!*@This() { + var frame: *@This() = undefined; + var fiber = try Fiber.init(stack, 0, entrypoint, .{ scheduler, stack, Result, &frame, func, args }); + fiber.switchTo(); + return frame; +} + +pub fn deinit(self: *@This()) void { + debug("deinit: {}", .{self}); + std.debug.assert(self.status == .completed); + self.scheduler.frames.remove(&self.link); + self.scheduler.num_complete -= 1; + while (self.waiters.popFirst()) |node| node.data.cast().wakeup(.reset_event); + if (self.stack) |stack| self.scheduler.allocator.free(stack); + // stack is now gone, doing anything after this with @This() is ub +} + +pub inline fn wakeup(self: *@This(), expected_status: Status) void { + std.debug.assert(self.status == expected_status); + debug("waking up: {}", .{self}); + self.status = .active; + self.fiber.switchTo(); +} + +pub fn yield(status: Status) void { + if (current()) |frame| { + if (status != .completed and status != .io_cancel and frame.canceled) return; + std.debug.assert(frame.status == .active); + frame.status = status; + debug("yielding: {}", .{frame}); + Fiber.yield(); + } else { + unreachable; // yield can only be used from a frame + } +} + 
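+/// How `complete` resolves a frame that has not finished yet:
+/// `.wait` keeps waiting until the frame completes on its own, while `.cancel`
+/// first calls `tryCancel` and then collects whatever partial result is available.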
+pub const CompleteMode = enum { wait, cancel }; + +pub fn complete(self: *@This(), mode: CompleteMode, comptime Result: type) Result { + if (self.canceled) return; + + debug("complete: {}, {}", .{ self, mode }); + + if (current()) |frame| { + while (self.status != .completed) { + if (mode == .cancel and tryCancel(self)) break; + self.waiters.prepend(&frame.wait_link); + yield(.reset_event); + } + } else { + while (self.status != .completed) { + if (mode == .cancel and tryCancel(self)) break; + if (self.scheduler.tick(.nonblocking) catch 0 == 0) break; + } + } + + const res: Result = @as(*Result, @ptrCast(@alignCast(self.result))).*; + self.deinit(); + return res; +} + +pub fn tryCancel(self: *@This()) bool { + if (self.status != .completed) { + self.canceled = true; + switch (self.status) { + .active => {}, + .io => { + debug("cancel... pending on io: {}", .{self}); + self.wakeup(.io); // cancel io + }, + .io_cancel => { + debug("cancel... pending on io: {}", .{self}); + }, + .reset_event => { + debug("cancel... reset event: {}", .{self}); + self.wakeup(.reset_event); + }, + .semaphore => { + debug("cancel... semaphore: {}", .{self}); + self.wakeup(.semaphore); + }, + .yield => { + debug("cancel... yield: {}", .{self}); + self.wakeup(.yield); + }, + .completed => unreachable, + } + } + return self.status == .completed; +} + +fn debug(comptime fmt: []const u8, args: anytype) void { + if (@import("builtin").is_test) { + std.debug.print("coro: " ++ fmt ++ "\n", args); + } else { + const options = @import("../coro.zig").options; + if (comptime !options.debug) return; + const scope = std.log.scoped(.coro); + scope.debug(fmt, args); + } +} diff --git a/src/coro/Scheduler.zig b/src/coro/Scheduler.zig new file mode 100644 index 0000000..4a31549 --- /dev/null +++ b/src/coro/Scheduler.zig @@ -0,0 +1,110 @@ +const std = @import("std"); +const aio = @import("aio"); +const io = @import("io.zig"); +const Frame = @import("Frame.zig"); +const Task = @import("Task.zig"); +const common = @import("common.zig"); +const options = @import("../coro.zig").options; + +allocator: std.mem.Allocator, +io: aio.Dynamic, +frames: Frame.List = .{}, +num_complete: usize = 0, + +fn ioComplete(uop: aio.Dynamic.Uop, id: aio.Id, failed: bool) void { + switch (uop) { + inline else => |*op| { + std.debug.assert(op.userdata != 0); + if (@TypeOf(op.*) == aio.Nop) { + if (op.domain != .coro) return; + switch (op.ident) { + 's' => {}, + 'r' => {}, + else => unreachable, + } + } else { + var ctx: *common.OperationContext = @ptrFromInt(op.userdata); + var frame: *Frame = ctx.whole.frame; + std.debug.assert(ctx.whole.num_operations > 0); + ctx.id = id; + ctx.completed = true; + ctx.whole.num_operations -= 1; + ctx.whole.num_errors += @intFromBool(failed); + if (ctx.whole.num_operations == 0) { + switch (frame.status) { + .io, .io_cancel => frame.wakeup(frame.status), + else => unreachable, + } + } + } + }, + } +} + +pub const InitOptions = struct { + /// This is a hint, the implementation makes the final call + io_queue_entries: u16 = options.io_queue_entries, +}; + +pub fn init(allocator: std.mem.Allocator, opts: InitOptions) aio.Error!@This() { + var work = try aio.Dynamic.init(allocator, opts.io_queue_entries); + work.callback = ioComplete; + return .{ .allocator = allocator, .io = work }; +} + +pub fn deinit(self: *@This()) void { + self.run(.cancel) catch @panic("unrecovable"); + var next = self.frames.first; + while (next) |node| { + next = node.next; + var frame = node.data.cast(); + frame.deinit(); + } + 
self.io.deinit(self.allocator); + self.* = undefined; +} + +pub const SpawnError = Frame.Error; + +pub const SpawnOptions = struct { + stack: union(enum) { + unmanaged: Frame.Stack, + managed: usize, + } = .{ .managed = options.stack_size }, +}; + +/// Spawns a new task, the task may do local IO operations which will not block the whole process using the `io` namespace functions +/// Call `frame.complete` to collect the result and free the stack +pub fn spawn(self: *@This(), comptime func: anytype, args: anytype, opts: SpawnOptions) SpawnError!Task.Generic(func) { + const stack = switch (opts.stack) { + .unmanaged => |buf| buf, + .managed => |sz| try self.allocator.alignedAlloc(u8, Frame.stack_alignment, sz), + }; + errdefer if (opts.stack == .managed) self.allocator.free(stack); + const frame = try Frame.init(self, stack, Task.Generic(func).Result, func, args); + return .{ .frame = frame }; +} + +/// Step the scheduler by a single step. +/// If `mode` is `.blocking` will block until there is `IO` activity. +/// Returns the number of tasks running. +pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.Error!usize { + _ = try self.io.complete(mode); + return self.frames.len - self.num_complete; +} + +pub const CompleteMode = Frame.CompleteMode; + +/// Run until all tasks are complete. +pub fn run(self: *@This(), mode: CompleteMode) aio.Error!void { + while (true) { + if (mode == .cancel) { + var next = self.frames.first; + while (next) |node| { + next = node.next; + _ = node.data.cast().tryCancel(); + } + } + if (try self.tick(.blocking) == 0) break; + } +} diff --git a/src/coro/Sync.zig b/src/coro/Sync.zig new file mode 100644 index 0000000..2c3425f --- /dev/null +++ b/src/coro/Sync.zig @@ -0,0 +1,42 @@ +const Frame = @import("Frame.zig"); + +pub const Semaphore = struct { + waiters: Frame.WaitList = .{}, + counter: u32 = 0, + + pub fn lock(self: *@This()) void { + self.counter += 1; + if (Frame.current()) |frame| { + self.waiters.prepend(&frame.wait_link); + while (self.counter > 0) Frame.yield(.semaphore); + } else unreachable; // can only be used in tasks + } + + pub fn unlock(self: *@This()) void { + if (self.counter == 0) return; + self.counter -= 1; + while (self.waiters.popFirst()) |node| node.data.cast().wakeup(.semaphore); + } +}; + +pub const ResetEvent = struct { + waiters: Frame.WaitList = .{}, + is_set: bool = false, + + pub fn wait(self: *@This()) void { + if (self.is_set) return; + if (Frame.current()) |frame| { + self.waiters.prepend(&frame.wait_link); + while (!self.is_set) Frame.yield(.reset_event); + } else unreachable; // can only be used in tasks + } + + pub fn set(self: *@This()) void { + self.is_set = true; + while (self.waiters.popFirst()) |node| node.data.cast().wakeup(.reset_event); + } + + pub fn reset(self: *@This()) void { + self.is_set = false; + } +}; diff --git a/src/coro/Task.zig b/src/coro/Task.zig new file mode 100644 index 0000000..c153490 --- /dev/null +++ b/src/coro/Task.zig @@ -0,0 +1,87 @@ +const std = @import("std"); +const ReturnType = @import("common.zig").ReturnType; +const Frame = @import("Frame.zig"); + +const Task = @This(); + +frame: *align(@alignOf(Frame)) anyopaque, + +fn cast(self: @This()) *Frame { + return @ptrCast(self.frame); +} + +pub fn format(self: @This(), comptime fmt: []const u8, opts: std.fmt.FormatOptions, writer: anytype) !void { + return self.cast().format(fmt, opts, writer); +} + +/// Get the current task, or null if not inside a task +pub inline fn current() ?@This() { + if (Frame.current()) |frame| { + return 
.{ .frame = frame }; + } else { + return null; + } +} + +pub inline fn state(self: @This(), T: type) T { + return @enumFromInt(self.cast().yield_state); +} + +pub inline fn yield(yield_state: anytype) void { + if (Frame.current()) |frame| { + frame.yield_state = @intFromEnum(yield_state); + if (frame.yield_state == 0) @panic("yield_state `0` is reserved"); + Frame.yield(.yield); + } else { + unreachable; + } +} + +pub inline fn wakeup(self: @This()) void { + self.cast().yield_state = 0; + self.cast().wakeup(.yield); +} + +pub const CompleteMode = Frame.CompleteMode; + +pub inline fn complete(self: @This(), mode: CompleteMode, comptime func: anytype) ReturnType(func) { + return self.cast().complete(mode, ReturnType(func)); +} + +pub fn generic(self: @This(), comptime func: anytype) Generic(func) { + return .{ .node = self.node }; +} + +pub fn Generic(comptime func: anytype) type { + return struct { + pub const Result = ReturnType(func); + pub const Fn = @TypeOf(func); + + frame: *align(@alignOf(Frame)) anyopaque, + comptime entry: Fn = func, + + fn cast(self: *@This()) *Frame { + return @ptrCast(self.frame); + } + + pub fn format(self: @This(), comptime fmt: []const u8, opts: std.fmt.FormatOptions, writer: anytype) !void { + return self.any().format(fmt, opts, writer); + } + + pub inline fn state(self: @This(), T: type) T { + return self.any().state(T); + } + + pub inline fn wakeup(self: @This()) void { + self.any().wakeup(); + } + + pub inline fn complete(self: @This(), mode: CompleteMode) Result { + return self.any().complete(mode, func); + } + + pub inline fn any(self: @This()) Task { + return .{ .frame = self.frame }; + } + }; +} diff --git a/src/coro/ThreadPool.zig b/src/coro/ThreadPool.zig new file mode 100644 index 0000000..511ef39 --- /dev/null +++ b/src/coro/ThreadPool.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const aio = @import("aio"); +const io = @import("io.zig"); +const Frame = @import("Frame.zig"); +const ReturnTypeWithError = @import("common.zig").ReturnTypeWithError; +const ReturnType = @import("common.zig").ReturnType; + +pool: std.Thread.Pool = undefined, +source: aio.EventSource = undefined, +num_tasks: std.atomic.Value(u32) = std.atomic.Value(u32).init(0), + +/// Spin up the pool, `allocator` is used to allocate the tasks +/// If `num_jobs` is zero, the thread count for the current CPU is used +pub fn start(self: *@This(), allocator: std.mem.Allocator, num_jobs: u32) !void { + self.* = .{ .pool = .{ .allocator = undefined, .threads = undefined }, .source = try aio.EventSource.init() }; + errdefer self.source.deinit(); + try self.pool.init(.{ .allocator = allocator, .n_jobs = if (num_jobs == 0) null else num_jobs }); +} + +pub fn deinit(self: *@This()) void { + self.pool.deinit(); + self.source.deinit(); + self.* = undefined; +} + +inline fn entrypoint(self: *@This(), completed: *bool, cancellation: *bool, comptime func: anytype, res: anytype, args: anytype) void { + _ = cancellation; // TODO + res.* = @call(.auto, func, args); + completed.* = true; + const n = self.num_tasks.load(.acquire); + for (0..n) |_| self.source.notify(); +} + +/// Yield until `func` finishes on another thread +pub fn yieldForCompletition(self: *@This(), func: anytype, args: anytype) ReturnTypeWithError(func, std.Thread.SpawnError) { + var completed: bool = false; + var res: ReturnType(func) = undefined; + _ = self.num_tasks.fetchAdd(1, .monotonic); + defer _ = self.num_tasks.fetchSub(1, .release); + var cancellation: bool = false; + try self.pool.spawn(entrypoint, .{ self, &completed, 
&cancellation, func, &res, args }); + while (!completed) { + const nerr = io.do(.{ + aio.WaitEventSource{ .source = &self.source, .link = .soft }, + }, if (cancellation) .io_cancel else .io) catch 1; + if (nerr > 0) { + if (Frame.current()) |frame| { + if (frame.canceled) { + cancellation = true; + continue; + } + } + // it's possible to end up here if aio implementation ran out of resources + // in case of io_uring the application managed to fill up the submission queue + // normally this should not happen, but as to not crash the program do a blocking wait + self.source.wait(); + } + } + return res; +} + +const ThreadPool = @This(); +const Scheduler = @import("Scheduler.zig"); + +test "ThreadPool" { + const Test = struct { + fn blocking() u32 { + std.time.sleep(1 * std.time.ns_per_s); + return 69; + } + fn task(pool: *ThreadPool) !void { + const ret = try pool.yieldForCompletition(blocking, .{}); + try std.testing.expectEqual(69, ret); + } + }; + var scheduler = try Scheduler.init(std.testing.allocator, .{}); + defer scheduler.deinit(); + var pool: ThreadPool = .{}; + try pool.start(std.testing.allocator, 0); + defer pool.deinit(); + for (0..10) |_| _ = try scheduler.spawn(Test.task, .{&pool}, .{}); + try scheduler.run(); +} diff --git a/src/coro/common.zig b/src/coro/common.zig new file mode 100644 index 0000000..1227a07 --- /dev/null +++ b/src/coro/common.zig @@ -0,0 +1,48 @@ +const std = @import("std"); +const aio = @import("aio"); +const Frame = @import("Frame.zig"); + +pub const WholeContext = struct { + num_operations: u16, + num_errors: u16 = 0, + frame: *Frame, +}; + +pub const OperationContext = struct { + whole: *WholeContext, + id: aio.Id = @enumFromInt(0), + completed: bool = false, +}; + +pub fn Link(comptime T: type, comptime field: []const u8, comptime container: enum { single, double }) type { + return struct { + pub inline fn cast(self: *@This()) *T { + switch (container) { + .single => { + const node: *std.SinglyLinkedList(@This()).Node = @alignCast(@fieldParentPtr("data", self)); + return @fieldParentPtr(field, node); + }, + .double => { + const node: *std.DoublyLinkedList(@This()).Node = @alignCast(@fieldParentPtr("data", self)); + return @fieldParentPtr(field, node); + }, + } + } + }; +} + +pub fn ReturnType(comptime func: anytype) type { + const base = @typeInfo(@TypeOf(func)).Fn.return_type.?; + return switch (@typeInfo(base)) { + .ErrorUnion => |eu| eu.error_set!eu.payload, + else => base, + }; +} + +pub fn ReturnTypeWithError(comptime func: anytype, comptime Error: type) type { + const base = @typeInfo(@TypeOf(func)).Fn.return_type.?; + return switch (@typeInfo(base)) { + .ErrorUnion => |eu| (Error || eu.error_set)!eu.payload, + else => Error!base, + }; +} diff --git a/src/coro/io.zig b/src/coro/io.zig new file mode 100644 index 0000000..afde7ee --- /dev/null +++ b/src/coro/io.zig @@ -0,0 +1,50 @@ +const std = @import("std"); +const aio = @import("aio"); +const Scheduler = @import("Scheduler.zig"); +const Frame = @import("Frame.zig"); +const common = @import("common.zig"); + +pub const Error = aio.Error || error{OperationCanceled}; + +pub fn do(operations: anytype, status: Frame.Status) Error!u16 { + std.debug.assert(status == .io or status == .io_cancel); + if (Frame.current()) |frame| { + if (frame.canceled) return error.OperationCanceled; + + var work = struct { ops: @TypeOf(operations) }{ .ops = operations }; + var whole: common.WholeContext = .{ .num_operations = operations.len, .frame = frame }; + var ctx_list: [operations.len]common.OperationContext = 
undefined; + + inline for (&work.ops, &ctx_list) |*op, *ctx| { + ctx.* = .{ .whole = &whole }; + op.userdata = @intFromPtr(ctx); + } + + try frame.scheduler.io.queue(work.ops); + Frame.yield(status); + + // check if this was a cancel + if (whole.num_operations > 0) { + var num_cancels: u16 = 0; + inline for (&ctx_list) |*ctx| { + if (!ctx.completed) { + // reuse the same ctx, we don't care about the completion status anymore + try frame.scheduler.io.queue(aio.Cancel{ .id = ctx.id, .userdata = @intFromPtr(ctx) }); + num_cancels += 1; + } + } + + std.debug.assert(num_cancels > 0); + whole.num_operations += num_cancels; + Frame.yield(.io_cancel); + + if (whole.num_errors == 0) { + return error.OperationCanceled; + } + } + + return whole.num_errors; + } else { + return aio.complete(operations); + } +}
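+
+// Illustrative note (assumed usage, mirroring `single`/`complete` in coro.zig):
+// inside a task a call such as
+//     try coro.io.single(aio.Timeout{ .ns = 1 * std.time.ns_per_s });
+// ends up here; `do` queues the operation, yields the frame with `.io`, and the
+// scheduler's completion callback wakes the frame once the operation finishes.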