Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Utf8StreamReader #9

Merged
merged 2 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sandbox/BedrockBlazorApp1/Components/Pages/Home.razor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace BedrockBlazorApp1.Components.Pages;
public partial class Home
{
[Inject]
public AmazonBedrockRuntimeClient BedrockClient { get; init; }
private BedrockAnthropicClient anthropic;
public required AmazonBedrockRuntimeClient BedrockClient { get; init; }
private BedrockAnthropicClient anthropic = default!;

double temperature = 1.0;
string textInput = "";
Expand Down
1 change: 0 additions & 1 deletion sandbox/BedrockConsoleApp/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Buffers;
using System.Collections.Generic;
using System.Formats.Asn1;
using System.IO.Pipelines;
using System.Reflection.PortableExecutable;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices.ObjectiveC;
Expand Down
2 changes: 1 addition & 1 deletion src/Claudia/Anthropic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async IAsyncEnumerable<IMessageStreamEvent> IMessages.CreateStreamAsync(MessageR
using var stream = await msg.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(ConfigureAwait);
#endif

var reader = new StreamMessageReader(stream, ConfigureAwait);
using var reader = new StreamMessageReader(stream, ConfigureAwait);

await foreach (var item in reader.ReadMessagesAsync(cancellationToken))
{
Expand Down
2 changes: 1 addition & 1 deletion src/Claudia/Claudia.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="System.IO.Pipelines" Version="8.0.0" />
<PackageReference Include="Utf8StreamReader" Version="0.0.4" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'net6.0' Or $(TargetFramework) == 'netstandard2.1'">
Expand Down
225 changes: 96 additions & 129 deletions src/Claudia/StreamMessageReader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.IO.Pipelines;
using Cysharp.IO;
using System.Runtime.CompilerServices;
using System.Text.Json;

Expand All @@ -9,171 +7,140 @@ namespace Claudia;
// parser of server-sent events
// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events

internal class StreamMessageReader
internal class StreamMessageReader : IDisposable
{
readonly PipeReader reader;
readonly Utf8StreamReader reader;
readonly bool configureAwait;
MessageStreamEventKind currentEvent;

public StreamMessageReader(Stream stream, bool configureAwait)
{
this.reader = PipeReader.Create(stream);
this.reader = new Utf8StreamReader(stream, leaveOpen: true) { ConfigureAwait = configureAwait };
this.configureAwait = configureAwait;
}

public async IAsyncEnumerable<IMessageStreamEvent> ReadMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
READ_AGAIN:
var readResult = await reader.ReadAsync(cancellationToken).ConfigureAwait(configureAwait);

if (!(readResult.IsCompleted | readResult.IsCanceled))
while (await reader.LoadIntoBufferAsync(cancellationToken).ConfigureAwait(configureAwait))
{
var buffer = readResult.Buffer;

while (TryReadData(ref buffer, out var streamEvent))
while (reader.TryReadLine(out var line))
{
var streamEvent = ParseLine(line);
if (streamEvent == null) continue;

yield return streamEvent;

if (streamEvent.TypeKind == MessageStreamEventKind.MessageStop)
{
yield break;
}
}

reader.AdvanceTo(buffer.Start, buffer.End); // examined is important
goto READ_AGAIN;
}
}

[SkipLocalsInit] // optimize stackalloc cost
bool TryReadData(ref ReadOnlySequence<byte> buffer, [NotNullWhen(true)] out IMessageStreamEvent? streamEvent)
IMessageStreamEvent? ParseLine(ReadOnlyMemory<byte> line)
{
var reader = new SequenceReader<byte>(buffer);
Span<byte> tempBytes = stackalloc byte[64]; // alloc temp

while (reader.TryReadTo(out ReadOnlySequence<byte> line, (byte)'\n', advancePastDelimiter: true))
{
// line is these kinds
// event: event_name
// data: json
// (empty line)

if (line.Length == 0)
{
continue; // next.
}
else if (line.FirstSpan[0] == 'e') // event
{
// Parse Event.
if (!line.IsSingleSegment)
{
line.CopyTo(tempBytes);
}
var span = line.IsSingleSegment ? line.FirstSpan : tempBytes.Slice(0, (int)line.Length);
// line is these kinds
// event: event_name
// data: json
// (empty line)

var first = span[7]; // "event: [c|m|p|e]"
var span = line.Span;

if (first == 'c') // content_block_start/delta/stop
{
switch (span[23]) // event: content_block_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.ContentBlockStart;
break;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.ContentBlockStop;
break;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.ContentBlockDelta;
break;
default:
break;
}
}
else if (first == 'm') // message_start/delta/stop
{
switch (span[17]) // event: message_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.MessageStart;
break;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.MessageStop;
break;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.MessageDelta;
break;
default:
break;
}
}
else if (first == 'p')
{
currentEvent = MessageStreamEventKind.Ping;
}
else if (first == 'e')
{
currentEvent = (MessageStreamEventKind)(-1);
}
else
{
// Unknown Event, Skip.
// throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray()));
currentEvent = (MessageStreamEventKind)(-2);
}
if (span.Length == 0)
{
return null; // next.
}
else if (span[0] == 'e') // event
{
// Parse Event.
var first = span[7]; // "event: [c|m|p|e]"

continue;
}
else if (line.FirstSpan[0] == 'd') // data
if (first == 'c') // content_block_start/delta/stop
{
// Parse Data.
Utf8JsonReader jsonReader;
if (line.IsSingleSegment)
{
jsonReader = new Utf8JsonReader(line.FirstSpan.Slice(6)); // skip data:
}
else
switch (span[23]) // event: content_block_..[]
{
jsonReader = new Utf8JsonReader(line.Slice(6)); // ReadOnlySequence.Slice is slightly slow
}

switch (currentEvent)
{
case MessageStreamEventKind.Ping:
streamEvent = JsonSerializer.Deserialize<Ping>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.ContentBlockStart;
break;
case MessageStreamEventKind.MessageStart:
streamEvent = JsonSerializer.Deserialize<MessageStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.ContentBlockStop;
break;
case MessageStreamEventKind.MessageDelta:
streamEvent = JsonSerializer.Deserialize<MessageDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.ContentBlockDelta;
break;
case MessageStreamEventKind.MessageStop:
streamEvent = JsonSerializer.Deserialize<MessageStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
default:
break;
case MessageStreamEventKind.ContentBlockStart:
streamEvent = JsonSerializer.Deserialize<ContentBlockStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
}
}
else if (first == 'm') // message_start/delta/stop
{
switch (span[17]) // event: message_..[]
{
case (byte)'a': // st[a]rt
currentEvent = MessageStreamEventKind.MessageStart;
break;
case MessageStreamEventKind.ContentBlockDelta:
streamEvent = JsonSerializer.Deserialize<ContentBlockDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'o': // st[o]p
currentEvent = MessageStreamEventKind.MessageStop;
break;
case MessageStreamEventKind.ContentBlockStop:
streamEvent = JsonSerializer.Deserialize<ContentBlockStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (byte)'l': // de[l]ta
currentEvent = MessageStreamEventKind.MessageDelta;
break;
case (MessageStreamEventKind)(-1):
var error = JsonSerializer.Deserialize<ErrorResponseShape>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options);
throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message);
default:
// unknown event, skip
goto END;
break;
}
}
else if (first == 'p')
{
currentEvent = MessageStreamEventKind.Ping;
}
else if (first == 'e')
{
currentEvent = (MessageStreamEventKind)(-1);
}
else
{
// Unknown Event, Skip.
// throw new InvalidOperationException("Unknown Event. Line:" + Encoding.UTF8.GetString(line.ToArray()));
currentEvent = (MessageStreamEventKind)(-2);
}

buffer = buffer.Slice(reader.Consumed);
return true;
return null; // continue
}
else if (span[0] == 'd') // data
{
// Parse Data.
var jsonReader = new Utf8JsonReader(span.Slice(6)); // skip data:
switch (currentEvent)
{
case MessageStreamEventKind.Ping:
return JsonSerializer.Deserialize<Ping>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageStart:
return JsonSerializer.Deserialize<MessageStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageDelta:
return JsonSerializer.Deserialize<MessageDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.MessageStop:
return JsonSerializer.Deserialize<MessageStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockStart:
return JsonSerializer.Deserialize<ContentBlockStart>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockDelta:
return JsonSerializer.Deserialize<ContentBlockDelta>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case MessageStreamEventKind.ContentBlockStop:
return JsonSerializer.Deserialize<ContentBlockStop>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options)!;
case (MessageStreamEventKind)(-1):
var error = JsonSerializer.Deserialize<ErrorResponseShape>(ref jsonReader, AnthropicJsonSerialzierContext.Default.Options);
throw new ClaudiaException(error!.ErrorResponse.ToErrorCode(), error.ErrorResponse.Type, error.ErrorResponse.Message);
default:
// unknown event, skip
return null;
}
}
END:
streamEvent = default;
buffer = buffer.Slice(reader.Consumed);
return false;

return null;
}

public void Dispose()
{
reader.Dispose();
}
}
4 changes: 2 additions & 2 deletions tests/Claudia.Tests/StreamMessageReaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task Text()

var ms = new MemoryStream(Encoding.UTF8.GetBytes(data));

var reader = new StreamMessageReader(ms, true);
using var reader = new StreamMessageReader(ms, true);

var array = await reader.ReadMessagesAsync(CancellationToken.None)
.ToObservable()
Expand Down Expand Up @@ -154,7 +154,7 @@ public async Task WithNewLine()

var ms = new MemoryStream(Encoding.UTF8.GetBytes(data));

var reader = new StreamMessageReader(ms, true);
using var reader = new StreamMessageReader(ms, true);

var array = await reader.ReadMessagesAsync(CancellationToken.None)
.ToObservable()
Expand Down
Loading