diff --git a/docs/pages/coro-blocking-code.mdx b/docs/pages/coro-blocking-code.mdx new file mode 100644 index 0000000..3c211de --- /dev/null +++ b/docs/pages/coro-blocking-code.mdx @@ -0,0 +1,31 @@ +# CORO API + +## Mixing blocking code + +Sometimes it's not feasible to rewrite blocking code so that it plays nice with the `coro.Scheduler`. +In that kind of scenario it is possible to use `coro.ThreadPool` to allow tasks to yield until blocking code +finishes on a worker thread. + +### Example + +```zig +fn blockingCode() u32 { + std.time.sleep(1 * std.time.ns_per_s); + return 69; +} + +fn task(pool: *ThreadPool) !void { + const ret = try pool.yieldForCompletition(blocking, .{}); + try std.testing.expectEqual(69, ret); +} + +var pool: ThreadPool = .{}; +defer pool.deinit(); // pool must always be destroyed before scheduler +try pool.start(std.testing.allocator, 0); + +var scheduler = try Scheduler.init(std.testing.allocator, .{}); +defer scheduler.deinit(); +_ = try scheduler.spawn(task, .{&pool}, .{}); +try scheduler.run(); +``` + diff --git a/docs/vocs.config.tsx b/docs/vocs.config.tsx index 544cc98..99cff96 100644 --- a/docs/vocs.config.tsx +++ b/docs/vocs.config.tsx @@ -77,6 +77,10 @@ export default defineConfig({ text: 'Context switches', link: '/coro-context-switches', }, + { + text: 'Mixing blocking code', + link: '/coro-blocking-code', + }, ], }, ], diff --git a/src/aio.zig b/src/aio.zig index b8d01d5..f4643c9 100644 --- a/src/aio.zig +++ b/src/aio.zig @@ -5,7 +5,6 @@ const std = @import("std"); pub const InitError = error{ - Overflow, OutOfMemory, PermissionDenied, ProcessQuotaExceeded, @@ -16,7 +15,7 @@ pub const InitError = error{ }; pub const QueueError = error{ - Overflow, + OutOfMemory, SubmissionQueueFull, }; @@ -101,7 +100,7 @@ pub inline fn multi(operations: anytype) (ImmediateError || error{SomeOperationF } /// Completes a single operation immediately, blocks until complete -pub inline fn single(operation: anytype) (ImmediateError || OperationError)!void { +pub inline fn single(operation: anytype) (ImmediateError || @TypeOf(operation).Error)!void { var op: @TypeOf(operation) = operation; var err: @TypeOf(operation).Error = error.Success; op.out_error = &err; @@ -109,6 +108,23 @@ pub inline fn single(operation: anytype) (ImmediateError || OperationError)!void if (err != error.Success) return err; } +pub const EventSource = struct { + native: IO.EventSource, + + pub inline fn init() InitError!@This() { + return .{ .native = try IO.EventSource.init() }; + } + + pub inline fn deinit(self: *@This()) void { + self.native.deinit(); + self.* = undefined; + } + + pub inline fn notify(self: *@This()) void { + self.native.notify(); + } +}; + const IO = switch (@import("builtin").target.os.tag) { .linux => @import("aio/linux.zig"), else => @compileError("unsupported os"), @@ -116,7 +132,6 @@ const IO = switch (@import("builtin").target.os.tag) { const ops = @import("aio/ops.zig"); pub const Id = ops.Id; -pub const OperationError = ops.Operation.Error; pub const Fsync = ops.Fsync; pub const Read = ops.Read; pub const Write = ops.Write; @@ -137,6 +152,9 @@ pub const SymlinkAt = ops.SymlinkAt; pub const ChildExit = ops.ChildExit; pub const Socket = ops.Socket; pub const CloseSocket = ops.CloseSocket; +pub const NotifyEventSource = ops.NotifyEventSource; +pub const WaitEventSource = ops.WaitEventSource; +pub const CloseEventSource = ops.CloseEventSource; test "shared outputs" { var tmp = std.testing.tmpDir(.{}); @@ -252,12 +270,11 @@ test "Timeout" { test "LinkTimeout" { var err: Timeout.Error = undefined; var expired: bool = undefined; - const res = try complete(.{ + const num_errors = try complete(.{ Timeout{ .ns = 2 * std.time.ns_per_s, .out_error = &err, .link_next = true }, LinkTimeout{ .ns = 1 * std.time.ns_per_s, .out_expired = &expired }, }); - try std.testing.expectEqual(2, res.num_completed); - try std.testing.expectEqual(1, res.num_errors); + try std.testing.expectEqual(1, num_errors); try std.testing.expectEqual(error.OperationCanceled, err); try std.testing.expectEqual(true, expired); } @@ -356,3 +373,12 @@ test "Socket" { }); try single(CloseSocket{ .socket = socket }); } + +test "EventSource" { + const source = try EventSource.init(); + try multi(.{ + NotifyEventSource{ .source = source }, + WaitEventSource{ .source = source, .link_next = true }, + CloseEventSource{ .source = source }, + }); +} diff --git a/src/aio/linux.zig b/src/aio/linux.zig index 0282ccf..f8e77b6 100644 --- a/src/aio/linux.zig +++ b/src/aio/linux.zig @@ -1,13 +1,36 @@ const std = @import("std"); const aio = @import("../aio.zig"); const Operation = @import("ops.zig").Operation; -const ErrorUnion = @import("ops.zig").ErrorUnion; + +pub const EventSource = struct { + fd: std.posix.fd_t, + + pub inline fn init() !@This() { + return .{ + .fd = std.posix.eventfd(0, std.os.linux.EFD.CLOEXEC) catch |err| return switch (err) { + error.SystemResources => error.SystemResources, + error.ProcessFdQuotaExceeded => error.ProcessQuotaExceeded, + error.SystemFdQuotaExceeded => error.SystemQuotaExceeded, + error.Unexpected => error.Unexpected, + }, + }; + } + + pub inline fn deinit(self: *@This()) void { + std.posix.close(self.fd); + self.* = undefined; + } + + pub inline fn notify(self: *@This()) void { + _ = std.posix.write(self.fd, &std.mem.toBytes(@as(u64, 1))) catch unreachable; + } +}; io: std.os.linux.IoUring, ops: Pool(Operation.Union, u16), pub fn init(allocator: std.mem.Allocator, n: u16) aio.InitError!@This() { - const n2 = try std.math.ceilPowerOfTwo(u16, n); + const n2 = std.math.ceilPowerOfTwo(u16, n) catch return error.SystemQuotaExceeded; var io = try uring_init(n2); errdefer io.deinit(); const ops = try Pool(Operation.Union, u16).init(allocator, n2); @@ -22,7 +45,7 @@ pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void { } inline fn queueOperation(self: *@This(), op: anytype) aio.QueueError!u16 { - const n = self.ops.next() orelse return error.Overflow; + const n = self.ops.next() orelse return error.OutOfMemory; try uring_queue(&self.io, op, n); const tag = @tagName(comptime Operation.tagFromPayloadType(@TypeOf(op.*))); return self.ops.add(@unionInit(Operation.Union, tag, op.*)) catch unreachable; @@ -63,7 +86,7 @@ pub fn complete(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.Completion } pub fn immediate(comptime len: u16, work: anytype) aio.ImmediateError!u16 { - var io = try uring_init(try std.math.ceilPowerOfTwo(u16, len)); + var io = try uring_init(std.math.ceilPowerOfTwo(u16, len) catch return error.SystemQuotaExceeded); defer io.deinit(); inline for (&work.ops, 0..) |*op, idx| try uring_queue(&io, op, idx); var num = try uring_submit(&io); @@ -126,6 +149,9 @@ fn convertOpenFlags(flags: std.fs.File.OpenFlags) std.posix.O { } inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) aio.QueueError!void { + const Trash = struct { + var u_64: u64 align(1) = undefined; + }; const RENAME_NOREPLACE = 1 << 0; var sqe = switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) { .fsync => try io.fsync(user_data, op.file.handle, 0), @@ -161,6 +187,9 @@ inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) ai .child_exit => try io.waitid(user_data, .PID, op.child, @constCast(&op._), std.posix.W.EXITED, 0), .socket => try io.socket(user_data, op.domain, op.flags, op.protocol, 0), .close_socket => try io.close(user_data, op.socket), + .notify_event_source => try io.write(user_data, op.source.native.fd, &std.mem.toBytes(@as(u64, 1)), 0), + .wait_event_source => try io.read(user_data, op.source.native.fd, .{ .buffer = std.mem.asBytes(&Trash.u_64) }, 0), + .close_event_source => try io.close(user_data, op.source.native.fd), }; if (op.link_next) sqe.flags |= std.os.linux.IOSQE_IO_LINK; if (@hasField(@TypeOf(op.*), "out_id")) { @@ -355,6 +384,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) else => std.posix.unexpectedErrno(err), }, .close_file, .close_dir, .close_socket => unreachable, + .notify_event_source, .wait_event_source, .close_event_source => unreachable, .timeout => switch (err) { .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, .TIME => error.Success, @@ -493,6 +523,7 @@ inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) }, .open_at => op.out_file.handle = cqe.res, .close_file, .close_dir, .close_socket => {}, + .notify_event_source, .wait_event_source, .close_event_source => {}, .timeout, .link_timeout => {}, .cancel => {}, .rename_at, .unlink_at, .mkdir_at, .symlink_at => {}, diff --git a/src/aio/ops.zig b/src/aio/ops.zig index 84b397a..42a4a69 100644 --- a/src/aio/ops.zig +++ b/src/aio/ops.zig @@ -1,5 +1,6 @@ const std = @import("std"); const builtin = @import("builtin"); +const aio = @import("../aio.zig"); // Virtual linked actions are possible with `nop` under io_uring :thinking: @@ -261,6 +262,30 @@ pub const CloseSocket = struct { link_next: bool = false, }; +pub const NotifyEventSource = struct { + pub const Error = SharedError; + source: aio.EventSource, + out_error: ?*Error = null, + counter: Counter = .nop, + link_next: bool = false, +}; + +pub const WaitEventSource = struct { + pub const Error = SharedError; + source: aio.EventSource, + out_error: ?*Error = null, + counter: Counter = .nop, + link_next: bool = false, +}; + +pub const CloseEventSource = struct { + pub const Error = SharedError; + source: aio.EventSource, + out_error: ?*Error = null, + counter: Counter = .nop, + link_next: bool = false, +}; + pub const Operation = enum { fsync, read, @@ -282,6 +307,9 @@ pub const Operation = enum { child_exit, socket, close_socket, + notify_event_source, + wait_event_source, + close_event_source, pub const map = std.enums.EnumMap(@This(), type).init(.{ .fsync = Fsync, @@ -304,6 +332,9 @@ pub const Operation = enum { .child_exit = ChildExit, .socket = Socket, .close_socket = CloseSocket, + .notify_event_source = NotifyEventSource, + .wait_event_source = WaitEventSource, + .close_event_source = CloseEventSource, }); pub fn tagFromPayloadType(comptime Op: type) @This() { @@ -316,12 +347,6 @@ pub const Operation = enum { unreachable; } - pub const Error = blk: { - var set = error{}; - for (Operation.map.values) |v| set = set || v.Error; - break :blk set; - }; - pub const Union = blk: { var fields: []const std.builtin.Type.UnionField = &.{}; for (Operation.map.values, 0..) |v, idx| fields = fields ++ .{.{ diff --git a/src/coro.zig b/src/coro.zig index 129c708..9a57c76 100644 --- a/src/coro.zig +++ b/src/coro.zig @@ -91,7 +91,7 @@ pub const io = struct { /// Completes a single operation immediately, blocks the coroutine until complete /// The IO operation can be cancelled by calling `wakeupFromIo`, or doing `aio.Cancel` - pub fn single(operation: anytype) (aio.ImmediateError || aio.OperationError)!void { + pub inline fn single(operation: anytype) (aio.ImmediateError || @TypeOf(operation).Error)!void { var op: @TypeOf(operation) = operation; var err: @TypeOf(operation).Error = error.Success; op.out_error = &err; @@ -244,7 +244,7 @@ pub const Scheduler = struct { self.* = undefined; } - fn entrypoint(self: *@This(), comptime func: anytype, args: anytype) void { + inline fn entrypoint(self: *@This(), comptime func: anytype, args: anytype) void { if (@typeInfo(@typeInfo(@TypeOf(func)).Fn.return_type.?) == .ErrorUnion) { @call(.auto, func, args) catch |err| options.error_handler(err); } else { @@ -331,3 +331,94 @@ pub const Scheduler = struct { while (self.tasks.len > 0) try self.tick(.blocking); } }; + +/// Synchronizes blocking tasks on a thread pool +pub const ThreadPool = struct { + pool: std.Thread.Pool = undefined, + + /// Spin up the pool, `allocator` is used to allocate the tasks + /// If `num_jobs` is zero, the thread count for the current CPU is used + pub fn start(self: *@This(), allocator: std.mem.Allocator, num_jobs: u32) !void { + self.* = .{ .pool = .{ .allocator = undefined, .threads = undefined } }; + try self.pool.init(.{ .allocator = allocator, .n_jobs = if (num_jobs == 0) null else num_jobs }); + } + + pub fn deinit(self: *@This()) void { + self.pool.deinit(); + self.* = undefined; + } + + inline fn entrypoint(source: *aio.EventSource, comptime func: anytype, ret: anytype, args: anytype) void { + ret.* = @call(.auto, func, args); + source.notify(); + } + + const Error = error{ + OutOfMemory, + SystemResources, + ProcessQuotaExceeded, + SystemQuotaExceeded, + Unexpected, + SomeOperationFailed, + } || aio.ImmediateError || aio.Read.Error; + + fn ReturnType(comptime Func: type) type { + const base = @typeInfo(Func).Fn.return_type.?; + return switch (@typeInfo(base)) { + .ErrorUnion => |eu| (Error || eu.error_set)!eu.payload, + else => Error!base, + }; + } + + /// Yield until `func` finishes on another thread + pub fn yieldForCompletition(self: *@This(), func: anytype, args: anytype) ReturnType(@TypeOf(func)) { + if (Fiber.current()) |_| { + var ret: @typeInfo(@TypeOf(func)).Fn.return_type.? = undefined; + var source = try aio.EventSource.init(); + errdefer source.deinit(); + try self.pool.spawn(entrypoint, .{ &source, func, &ret, args }); + try io.multi(.{ + aio.WaitEventSource{ .source = source, .link_next = true }, + aio.CloseEventSource{ .source = source }, + }); + return ret; + } else { + unreachable; // yieldForCompletition can only be used from a task + } + } +}; + +test "ThreadPool" { + const Test = struct { + const Yield = enum { + task2_free, + }; + + fn task1(task1_done: *bool) void { + yield(Yield.task2_free); + task1_done.* = true; + } + + fn blocking() u32 { + std.time.sleep(1 * std.time.ns_per_s); + return 69; + } + + fn task2(t1: Task, pool: *ThreadPool, task1_done: *bool) !void { + const ret = try pool.yieldForCompletition(blocking, .{}); + try std.testing.expectEqual(69, ret); + try std.testing.expectEqual(task1_done.*, false); + wakeupFromState(t1, Yield.task2_free); + try std.testing.expectEqual(task1_done.*, true); + } + }; + var pool: ThreadPool = .{}; + defer pool.deinit(); + try pool.start(std.testing.allocator, 0); + var scheduler = try Scheduler.init(std.testing.allocator, .{}); + defer scheduler.deinit(); + var task1_done: bool = false; + const task1 = try scheduler.spawn(Test.task1, .{&task1_done}, .{}); + _ = try scheduler.spawn(Test.task2, .{ task1, &pool, &task1_done }, .{}); + try scheduler.run(); +}