diff --git a/src/Claudia/Anthropic.cs b/src/Claudia/Anthropic.cs index cfebbd3..250647a 100644 --- a/src/Claudia/Anthropic.cs +++ b/src/Claudia/Anthropic.cs @@ -84,7 +84,7 @@ async IAsyncEnumerable IMessages.CreateStreamAsync(MessageR using var stream = await msg.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(ConfigureAwait); #endif - var reader = new StreamMessageReader(stream, ConfigureAwait); + using var reader = new StreamMessageReader(stream, ConfigureAwait); await foreach (var item in reader.ReadMessagesAsync(cancellationToken)) { diff --git a/src/Claudia/Claudia.csproj b/src/Claudia/Claudia.csproj index 1a851df..1b339a1 100644 --- a/src/Claudia/Claudia.csproj +++ b/src/Claudia/Claudia.csproj @@ -29,7 +29,7 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/src/Claudia/StreamMessageReader.cs b/src/Claudia/StreamMessageReader.cs index 441efc8..765a71a 100644 --- a/src/Claudia/StreamMessageReader.cs +++ b/src/Claudia/StreamMessageReader.cs @@ -1,6 +1,4 @@ -using System.Buffers; -using System.Diagnostics.CodeAnalysis; -using System.IO.Pipelines; +using Cysharp.IO; using System.Runtime.CompilerServices; using System.Text.Json; @@ -9,171 +7,141 @@ namespace Claudia; // parser of server-sent events // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events -internal class StreamMessageReader +internal class StreamMessageReader : IDisposable { - readonly PipeReader reader; + readonly Utf8StreamReader reader; readonly bool configureAwait; MessageStreamEventKind currentEvent; public StreamMessageReader(Stream stream, bool configureAwait) { - this.reader = PipeReader.Create(stream); + this.reader = new Utf8StreamReader(stream, leaveOpen: true) { ConfigureAwait = configureAwait }; this.configureAwait = configureAwait; } public async IAsyncEnumerable ReadMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - READ_AGAIN: - var readResult = await reader.ReadAsync(cancellationToken).ConfigureAwait(configureAwait); - - if (!(readResult.IsCompleted | readResult.IsCanceled)) + while (await reader.LoadIntoBufferAsync(cancellationToken).ConfigureAwait(configureAwait)) { - var buffer = readResult.Buffer; - - while (TryReadData(ref buffer, out var streamEvent)) + while (reader.TryReadLine(out var line)) { + var streamEvent = ParseLine(line); + if (streamEvent == null) continue; + yield return streamEvent; + if (streamEvent.TypeKind == MessageStreamEventKind.MessageStop) { yield break; } } - reader.AdvanceTo(buffer.Start, buffer.End); // examined is important - goto READ_AGAIN; } } - [SkipLocalsInit] // optimize stackalloc cost - bool TryReadData(ref ReadOnlySequence buffer, [NotNullWhen(true)] out IMessageStreamEvent? streamEvent) + IMessageStreamEvent? ParseLine(ReadOnlyMemory line) { - var reader = new SequenceReader(buffer); - Span tempBytes = stackalloc byte[64]; // alloc temp + // line is these kinds + // event: event_name + // data: json + // (empty line) - while (reader.TryReadTo(out ReadOnlySequence line, (byte)'\n', advancePastDelimiter: true)) - { - // line is these kinds - // event: event_name - // data: json - // (empty line) + var span = line.Span; - if (line.Length == 0) - { - continue; // next. - } - else if (line.FirstSpan[0] == 'e') // event - { - // Parse Event. - if (!line.IsSingleSegment) - { - line.CopyTo(tempBytes); - } - var span = line.IsSingleSegment ? line.FirstSpan : tempBytes.Slice(0, (int)line.Length); - - var first = span[7]; // "event: [c|m|p|e]" - - if (first == 'c') // content_block_start/delta/stop - { - switch (span[23]) // event: content_block_..[] - { - case (byte)'a': // st[a]rt - currentEvent = MessageStreamEventKind.ContentBlockStart; - break; - case (byte)'o': // st[o]p - currentEvent = MessageStreamEventKind.ContentBlockStop; - break; - case (byte)'l': // de[l]ta - currentEvent = MessageStreamEventKind.ContentBlockDelta; - break; - default: - break; - } - } - else if (first == 'm') // message_start/delta/stop - { - switch (span[17]) // event: message_..[] - { - case (byte)'a': // st[a]rt - currentEvent = MessageStreamEventKind.MessageStart; - break; - case (byte)'o': // st[o]p - currentEvent = MessageStreamEventKind.MessageStop; - break; - case (byte)'l': // de[l]ta - currentEvent = MessageStreamEventKind.MessageDelta; - break; - default: - break; - } - } - else if (first == 'p') - { - currentEvent = MessageStreamEventKind.Ping; - } - else if (first == 'e') - { - currentEvent = (MessageStreamEventKind)(-1); - } - else - { - // Unknown Event, Skip. - // throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray())); - currentEvent = (MessageStreamEventKind)(-2); - } + if (span.Length == 0) + { + return null; // next. + } + else if (span[0] == 'e') // event + { + // Parse Event. + var first = span[7]; // "event: [c|m|p|e]" - continue; - } - else if (line.FirstSpan[0] == 'd') // data + if (first == 'c') // content_block_start/delta/stop { - // Parse Data. - Utf8JsonReader jsonReader; - if (line.IsSingleSegment) + switch (span[23]) // event: content_block_..[] { - jsonReader = new Utf8JsonReader(line.FirstSpan.Slice(6)); // skip data: - } - else - { - jsonReader = new Utf8JsonReader(line.Slice(6)); // ReadOnlySequence.Slice is slightly slow - } - - switch (currentEvent) - { - case MessageStreamEventKind.Ping: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (byte)'a': // st[a]rt + currentEvent = MessageStreamEventKind.ContentBlockStart; break; - case MessageStreamEventKind.MessageStart: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (byte)'o': // st[o]p + currentEvent = MessageStreamEventKind.ContentBlockStop; break; - case MessageStreamEventKind.MessageDelta: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (byte)'l': // de[l]ta + currentEvent = MessageStreamEventKind.ContentBlockDelta; break; - case MessageStreamEventKind.MessageStop: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + default: break; - case MessageStreamEventKind.ContentBlockStart: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + } + } + else if (first == 'm') // message_start/delta/stop + { + switch (span[17]) // event: message_..[] + { + case (byte)'a': // st[a]rt + currentEvent = MessageStreamEventKind.MessageStart; break; - case MessageStreamEventKind.ContentBlockDelta: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (byte)'o': // st[o]p + currentEvent = MessageStreamEventKind.MessageStop; break; - case MessageStreamEventKind.ContentBlockStop: - streamEvent = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (byte)'l': // de[l]ta + currentEvent = MessageStreamEventKind.MessageDelta; break; - case (MessageStreamEventKind)(-1): - var error = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options); - throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message); default: - // unknown event, skip - goto END; + break; } + } + else if (first == 'p') + { + currentEvent = MessageStreamEventKind.Ping; + } + else if (first == 'e') + { + currentEvent = (MessageStreamEventKind)(-1); + } + else + { + // Unknown Event, Skip. + // throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray())); + currentEvent = (MessageStreamEventKind)(-2); + } - buffer = buffer.Slice(reader.Consumed); - return true; + return null; // continue + } + else if (span[0] == 'd') // data + { + // Parse Data. + var jsonReader = new Utf8JsonReader(span.Slice(6)); // skip data: + switch (currentEvent) + { + case MessageStreamEventKind.Ping: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.MessageStart: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.MessageDelta: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.MessageStop: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.ContentBlockStart: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.ContentBlockDelta: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case MessageStreamEventKind.ContentBlockStop: + return JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!; + case (MessageStreamEventKind)(-1): + var error = JsonSerializer.Deserialize(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options); + throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message); + default: + // unknown event, skip + return null; } } - END: - streamEvent = default; - buffer = buffer.Slice(reader.Consumed); - return false; + + return null; + } + + public void Dispose() + { + reader.Dispose(); } } diff --git a/tests/Claudia.Tests/StreamMessageReaderTest.cs b/tests/Claudia.Tests/StreamMessageReaderTest.cs index f8b866f..a6131c1 100644 --- a/tests/Claudia.Tests/StreamMessageReaderTest.cs +++ b/tests/Claudia.Tests/StreamMessageReaderTest.cs @@ -81,7 +81,7 @@ public async Task Text() var ms = new MemoryStream(Encoding.UTF8.GetBytes(data)); - var reader = new StreamMessageReader(ms, true); + using var reader = new StreamMessageReader(ms, true); var array = await reader.ReadMessagesAsync(CancellationToken.None) .ToObservable() @@ -154,7 +154,7 @@ public async Task WithNewLine() var ms = new MemoryStream(Encoding.UTF8.GetBytes(data)); - var reader = new StreamMessageReader(ms, true); + using var reader = new StreamMessageReader(ms, true); var array = await reader.ReadMessagesAsync(CancellationToken.None) .ToObservable()