Skip to content

Commit

Permalink
fix: Initial streaming support.
Browse files Browse the repository at this point in the history
  • Loading branch information
HavenDV committed Jul 22, 2024
1 parent 06faf09 commit bb9267c
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 186 deletions.
257 changes: 125 additions & 132 deletions src/libs/Anthropic/AnthropicApi.Streaming.cs
Original file line number Diff line number Diff line change
@@ -1,132 +1,125 @@
// using System.Net.Http;
// using System.Net.Http.Headers;
// using System.Runtime.CompilerServices;
//
// namespace Anthropic;
//
// public partial class AnthropicApi
// {
// /// <param name="body"></param>
// /// <param name="cancellationToken">A cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
// /// <summary>
// /// Creates a model response for the given chat conversation.
// /// </summary>
// /// <returns>OK</returns>
// /// <exception cref="ApiException">A server side error occurred.</exception>
// public virtual async IAsyncEnumerable<CompletionStreamResponseDelta> CreateCompletionAsStreamAsync(
// CreateCompletionRequest body,
// [EnumeratorCancellation] CancellationToken cancellationToken = default)
// {
// body = body ?? throw new ArgumentNullException(nameof(body));
// body.Stream = true;
//
// var urlBuilder = new System.Text.StringBuilder();
// urlBuilder.Append(BaseUrl.TrimEnd('/')).Append("/complete");
//
// var url = urlBuilder.ToString();
//
// using var request = new HttpRequestMessage(HttpMethod.Post, new Uri(url, UriKind.RelativeOrAbsolute))
// {
// Content = new StringContent(JsonSerializer.Serialize(body, _settings.Value))
// {
// Headers =
// {
// ContentType = MediaTypeHeaderValue.Parse("application/json"),
// }
// },
// Headers =
// {
// Accept =
// {
// MediaTypeWithQualityHeaderValue.Parse("text/event-stream"),
// },
// }
// };
//
// using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
// var headers = response.Headers.ToDictionary(
// static h => h.Key,
// static h => h.Value);
// foreach (var pair in response.Content.Headers)
// {
// headers[pair.Key] = pair.Value;
// }
//
// var status = (int)response.StatusCode;
// if (status != 200)
// {
// #if NET6_0_OR_GREATER
// var responseData = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
// #else
// var responseData = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
// #endif
//
// throw new ApiException(
// "The HTTP status code of the response was not expected (" + status + ").",
// status,
// responseData,
// headers,
// null);
// }
//
// #if NET6_0_OR_GREATER
// using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
// #else
// using var responseStream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
// #endif
// using var reader = new StreamReader(responseStream);
//
// // Continuously read the stream until the end of it
// while (!reader.EndOfStream)
// {
// cancellationToken.ThrowIfCancellationRequested();
//
// #if NET7_0_OR_GREATER
// var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false);
// #else
// var line = await reader.ReadLineAsync().ConfigureAwait(false);
// #endif
//
// // Skip empty lines
// if (string.IsNullOrEmpty(line))
// {
// continue;
// }
//
// var index = line.IndexOf('{');
// if (index >= 0)
// {
// line = line[index..];
// }
//
// if (line.StartsWith("event: ", StringComparison.OrdinalIgnoreCase))
// {
// continue;
// }
//
// CompletionStreamResponseDelta? block = null;
// try
// {
// // When the response is good, each line is a serializable CompletionCreateRequest
// block = JsonSerializer.Deserialize<CompletionStreamResponseDelta>(line);
// }
// catch (JsonException)
// {
// // When the API returns an error, it does not come back as a block, it returns a single character of text ("{").
// // In this instance, read through the rest of the response, which should be a complete object to parse.
// #if NET7_0_OR_GREATER
// line += await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
// #else
// line += await reader.ReadToEndAsync().ConfigureAwait(false);
// #endif
// }
//
// if (block == null)
// {
// throw new ApiException("Response was null which was not expected.", status, line, headers, null);
// }
//
// yield return block;
// }
// }
// }
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;

// ReSharper disable RedundantNameQualifier
// ReSharper disable InconsistentNaming

namespace Anthropic;

public partial class AnthropicApi
{
/// <summary>
/// Create a Message<br/>
/// Send a structured list of input messages with text and/or image content, and the<br/>
/// model will generate the next message in the conversation.<br/>
/// The Messages API can be used for either single queries or stateless multi-turn<br/>
/// conversations.
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <exception cref="global::System.InvalidOperationException"></exception>
public async IAsyncEnumerable<global::Anthropic.MessageStreamEvent> CreateMessageAsStreamAsync(
global::Anthropic.CreateMessageRequest request,
[EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default)
{
request = request ?? throw new global::System.ArgumentNullException(nameof(request));
request.Stream = true;

PrepareArguments(
client: _httpClient);
PrepareCreateMessageArguments(
httpClient: _httpClient,
request: request);

using var httpRequest = new global::System.Net.Http.HttpRequestMessage(
method: global::System.Net.Http.HttpMethod.Post,
requestUri: new global::System.Uri(_httpClient.BaseAddress?.AbsoluteUri.TrimEnd('/') + "/messages", global::System.UriKind.RelativeOrAbsolute));
var __json = global::System.Text.Json.JsonSerializer.Serialize(request, global::Anthropic.SourceGenerationContext.Default.CreateMessageRequest);
httpRequest.Content = new global::System.Net.Http.StringContent(
content: __json,
encoding: global::System.Text.Encoding.UTF8,
mediaType: "application/json");
httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));

PrepareRequest(
client: _httpClient,
request: httpRequest);
PrepareCreateMessageRequest(
httpClient: _httpClient,
httpRequestMessage: httpRequest,
request: request);

using var response = await _httpClient.SendAsync(
request: httpRequest,
completionOption: global::System.Net.Http.HttpCompletionOption.ResponseContentRead,
cancellationToken: cancellationToken).ConfigureAwait(false);

ProcessResponse(
client: _httpClient,
response: response);
ProcessCreateMessageResponse(
httpClient: _httpClient,
httpResponseMessage: response);

#if NET6_0_OR_GREATER
using var __content = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
#else
using var __content = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
#endif
using var reader = new StreamReader(__content);

// Continuously read the stream until the end of it
while (!reader.EndOfStream)
{
cancellationToken.ThrowIfCancellationRequested();

#if NET7_0_OR_GREATER
var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false);
#else
var line = await reader.ReadLineAsync().ConfigureAwait(false);
#endif

// Skip empty lines
if (string.IsNullOrEmpty(line))
{
continue;
}

var index = line.IndexOf('{');
if (index >= 0)
{
line = line[index..];
}

if (line.StartsWith("event: ", StringComparison.OrdinalIgnoreCase))
{
continue;
}

MessageStreamEvent? block = null;
try
{
// When the response is good, each line is a serializable CompletionCreateRequest
block = JsonSerializer.Deserialize(line, SourceGenerationContext.Default.NullableMessageStreamEvent);
}
catch (JsonException)
{
// When the API returns an error, it does not come back as a block, it returns a single character of text ("{").
// In this instance, read through the rest of the response, which should be a complete object to parse.
#if NET7_0_OR_GREATER
line += await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
#else
line += await reader.ReadToEndAsync().ConfigureAwait(false);
#endif
}

if (block == null)
{
throw new HttpRequestException(line);
}

yield return block.Value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@ public partial class JsonSerializerContextTypes
///
/// </summary>
public OpenApiSchema? OpenApiSchema { get; set; }

/// <summary>
///
/// </summary>
public MessageStreamEvent? MessageStreamEvent { get; set; }
}
35 changes: 0 additions & 35 deletions src/tests/Anthropic.IntegrationTests/Tests.Functions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,41 +53,6 @@
// }
// }
//
// private static void PrintMessages(List<Message> messages)
// {
// foreach (var message in messages)
// {
// if (message.IsTool)
// {
// Console.WriteLine($"> {message.Tool.Role}({message.Tool.ToolCallId}):");
// Console.WriteLine($"{message.Tool.Content}");
// }
// else if (message.IsSystem)
// {
// Console.WriteLine($"> {message.System.Role}: {message.System.Name}");
// Console.WriteLine($"{message.System.Content}");
// }
// else if (message.IsUser)
// {
// Console.WriteLine($"> {message.User.Role}: {message.User.Name}");
// Console.WriteLine($"{message.User.Content.Object}");
// }
// else if (message.IsAssistant)
// {
// Console.WriteLine($"> {message.Assistant.Role}: {message.Assistant.Name}");
// foreach (var call in message.Assistant.ToolCalls ?? Enumerable.Empty<ChatCompletionMessageToolCall>())
// {
// Console.WriteLine($"{call.Id}:");
// Console.WriteLine($"{call.Function.Name}({call.Function.Arguments})");
// }
// if (!string.IsNullOrWhiteSpace(message.Assistant.Content))
// {
// Console.WriteLine($"{message.Assistant.Content}");
// }
// }
// }
// }
//
// [TestMethod]
// public async Task Call()
// {
Expand Down
36 changes: 17 additions & 19 deletions src/tests/Anthropic.IntegrationTests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,21 @@ public async Task Tools()
response.StopReason.Should().Be(StopReason.ToolUse);
}

// [TestMethod]
// public async Task CreateChatCompletionAsStreamAsync()
// {
// using var api = GetAuthorizedApi();
//
// var enumerable = api.Chat.CreateChatCompletionAsStreamAsync(
// messages: new[]
// {
// "You are a helpful weather assistant.".AsSystemMessage(),
// "What's the weather like today?".AsUserMessage(),
// },
// model: CreateChatCompletionRequestModel.Gpt35Turbo,
// user: "tryAGI.Anthropic.IntegrationTests.Tests.CreateChatCompletion");
//
// await foreach (var response in enumerable)
// {
// Console.WriteLine(response.Choices.ElementAt(0).Delta.Content);
// }
// }
[TestMethod]
public async Task Streaming()
{
using var api = GetAuthorizedApi();

var enumerable = api.CreateMessageAsStreamAsync(new CreateMessageRequest
{
Model = CreateMessageRequestModel.Claude3Haiku20240307,
Messages = ["Once upon a time"],
MaxTokens = 250,
});

await foreach (var response in enumerable)
{
Console.WriteLine(response.ContentBlockDelta?.Delta.Value1?.Text);
}
}
}

0 comments on commit bb9267c

Please sign in to comment.