Skip to content

Commit

Permalink
Use Utf8StreamReader
Browse files Browse the repository at this point in the history
  • Loading branch information
neuecc committed Mar 27, 2024
1 parent 2f4e7b0 commit 11ffc2d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 132 deletions.
2 changes: 1 addition & 1 deletion src/Claudia/Anthropic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async IAsyncEnumerable<IMessageStreamEvent> IMessages.CreateStreamAsync(MessageR
using var stream = await msg.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(ConfigureAwait);
#endif

var reader = new StreamMessageReader(stream, ConfigureAwait);
using var reader = new StreamMessageReader(stream, ConfigureAwait);

await foreach (var item in reader.ReadMessagesAsync(cancellationToken))
{
Expand Down
2 changes: 1 addition & 1 deletion src/Claudia/Claudia.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="System.IO.Pipelines" Version="8.0.0" />
<PackageReference Include="Utf8StreamReader" Version="0.0.4" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'net6.0' Or $(TargetFramework) == 'netstandard2.1'">
Expand Down
224 changes: 96 additions & 128 deletions src/Claudia/StreamMessageReader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.IO.Pipelines;
using Cysharp.IO;
using System.Runtime.CompilerServices;
using System.Text.Json;

Expand All @@ -9,171 +7,141 @@ namespace Claudia;
// parser of server-sent events
// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events

internal class StreamMessageReader
internal class StreamMessageReader : IDisposable
{
readonly PipeReader reader;
readonly Utf8StreamReader reader;
readonly bool configureAwait;
MessageStreamEventKind currentEvent;

public StreamMessageReader(Stream stream, bool configureAwait)
{
this.reader = PipeReader.Create(stream);
this.reader = new Utf8StreamReader(stream, leaveOpen: true) { ConfigureAwait = configureAwait };
this.configureAwait = configureAwait;
}

public async IAsyncEnumerable<IMessageStreamEvent> ReadMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
READ_AGAIN:
var readResult = await reader.ReadAsync(cancellationToken).ConfigureAwait(configureAwait);

if (!(readResult.IsCompleted | readResult.IsCanceled))
while (await reader.LoadIntoBufferAsync(cancellationToken).ConfigureAwait(configureAwait))
{
var buffer = readResult.Buffer;

while (TryReadData(ref buffer, out var streamEvent))
while (reader.TryReadLine(out var line))
{
var streamEvent = ParseLine(line);
if (streamEvent == null) continue;

yield return streamEvent;

if (streamEvent.TypeKind == MessageStreamEventKind.MessageStop)
{
yield break;
}
}

reader.AdvanceTo(buffer.Start, buffer.End); // examined is important
goto READ_AGAIN;
}
}

[SkipLocalsInit] // optimize stackalloc cost
bool TryReadData(ref ReadOnlySequence<byte> buffer, [NotNullWhen(true)] out IMessageStreamEvent? streamEvent)
IMessageStreamEvent? ParseLine(ReadOnlyMemory<byte> line)
{
var reader = new SequenceReader<byte>(buffer);
Span<byte> tempBytes = stackalloc byte[64]; // alloc temp
// line is these kinds
// event: event_name
// data: json
// (empty line)

while (reader.TryReadTo(out ReadOnlySequence<byte> line, (byte)'\n', advancePastDelimiter: true))
{
// line is these kinds
// event: event_name
// data: json
// (empty line)
var span = line.Span;

if (line.Length == 0)
{
continue; // next.
}
else if (line.FirstSpan[0] == 'e') // event
{
// Parse Event.
if (!line.IsSingleSegment)
{
line.CopyTo(tempBytes);
}
var span = line.IsSingleSegment ? line.FirstSpan : tempBytes.Slice(0, (int)line.Length);

var first = span[7]; // "event: [c|m|p|e]"

if (first == 'c') // content_block_start/delta/stop
{
switch (span[23]) // event: content_block_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.ContentBlockStart;
break;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.ContentBlockStop;
break;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.ContentBlockDelta;
break;
default:
break;
}
}
else if (first == 'm') // message_start/delta/stop
{
switch (span[17]) // event: message_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.MessageStart;
break;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.MessageStop;
break;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.MessageDelta;
break;
default:
break;
}
}
else if (first == 'p')
{
currentEvent = MessageStreamEventKind.Ping;
}
else if (first == 'e')
{
currentEvent = (MessageStreamEventKind)(-1);
}
else
{
// Unknown Event, Skip.
// throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray()));
currentEvent = (MessageStreamEventKind)(-2);
}
if (span.Length == 0)
{
return null; // next.
}
else if (span[0] == 'e') // event
{
// Parse Event.
var first = span[7]; // "event: [c|m|p|e]"

continue;
}
else if (line.FirstSpan[0] == 'd') // data
if (first == 'c') // content_block_start/delta/stop
{
// Parse Data.
Utf8JsonReader jsonReader;
if (line.IsSingleSegment)
switch (span[23]) // event: content_block_..[]
{
jsonReader = new Utf8JsonReader(line.FirstSpan.Slice(6)); // skip data:
}
else
{
jsonReader = new Utf8JsonReader(line.Slice(6)); // ReadOnlySequence.Slice is slightly slow
}

switch (currentEvent)
{
case MessageStreamEventKind.Ping:
streamEvent = JsonSerializer.Deserialize<Ping>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.ContentBlockStart;
break;
case MessageStreamEventKind.MessageStart:
streamEvent = JsonSerializer.Deserialize<MessageStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.ContentBlockStop;
break;
case MessageStreamEventKind.MessageDelta:
streamEvent = JsonSerializer.Deserialize<MessageDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.ContentBlockDelta;
break;
case MessageStreamEventKind.MessageStop:
streamEvent = JsonSerializer.Deserialize<MessageStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
default:
break;
case MessageStreamEventKind.ContentBlockStart:
streamEvent = JsonSerializer.Deserialize<ContentBlockStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
}
}
else if (first == 'm') // message_start/delta/stop
{
switch (span[17]) // event: message_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.MessageStart;
break;
case MessageStreamEventKind.ContentBlockDelta:
streamEvent = JsonSerializer.Deserialize<ContentBlockDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.MessageStop;
break;
case MessageStreamEventKind.ContentBlockStop:
streamEvent = JsonSerializer.Deserialize<ContentBlockStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.MessageDelta;
break;
case (MessageStreamEventKind)(-1):
var error = JsonSerializer.Deserialize<ErrorResponseShape>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options);
throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message);
default:
// unknown event, skip
goto END;
break;
}
}
else if (first == 'p')
{
currentEvent = MessageStreamEventKind.Ping;
}
else if (first == 'e')
{
currentEvent = (MessageStreamEventKind)(-1);
}
else
{
// Unknown Event, Skip.
// throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray()));
currentEvent = (MessageStreamEventKind)(-2);
}

buffer = buffer.Slice(reader.Consumed);
return true;
return null; // continue
}
else if (span[0] == 'd') // data
{
// Parse Data.
var jsonReader = new Utf8JsonReader(span.Slice(6)); // skip data:
switch (currentEvent)
{
case MessageStreamEventKind.Ping:
return JsonSerializer.Deserialize<Ping>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageStart:
return JsonSerializer.Deserialize<MessageStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageDelta:
return JsonSerializer.Deserialize<MessageDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageStop:
return JsonSerializer.Deserialize<MessageStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockStart:
return JsonSerializer.Deserialize<ContentBlockStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockDelta:
return JsonSerializer.Deserialize<ContentBlockDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockStop:
return JsonSerializer.Deserialize<ContentBlockStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (MessageStreamEventKind)(-1):
var error = JsonSerializer.Deserialize<ErrorResponseShape>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options);
throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message);
default:
// unknown event, skip
return null;
}
}
END:
streamEvent = default;
buffer = buffer.Slice(reader.Consumed);
return false;

return null;
}

public void Dispose()
{
reader.Dispose();
}
}
4 changes: 2 additions & 2 deletions tests/Claudia.Tests/StreamMessageReaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task Text()

var ms = new MemoryStream(Encoding.UTF8.GetBytes(data));

var reader = new StreamMessageReader(ms, true);
using var reader = new StreamMessageReader(ms, true);

var array = await reader.ReadMessagesAsync(CancellationToken.None)
.ToObservable()
Expand Down Expand Up @@ -154,7 +154,7 @@ public async Task WithNewLine()

var ms = new MemoryStream(Encoding.UTF8.GetBytes(data));

var reader = new StreamMessageReader(ms, true);
using var reader = new StreamMessageReader(ms, true);

var array = await reader.ReadMessagesAsync(CancellationToken.None)
.ToObservable()
Expand Down

0 comments on commit 11ffc2d

Please sign in to comment.