Skip to content

Commit

Permalink
wip retry improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
neuecc committed Mar 13, 2024
1 parent f78899d commit 26df93e
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 20 deletions.
1 change: 1 addition & 0 deletions sandbox/ConsoleApp1/ConsoleApp1.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<NoWarn>1591</NoWarn>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>

Expand Down
69 changes: 68 additions & 1 deletion sandbox/ConsoleApp1/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,71 @@ public static int Sum(int x, int y)
{
return x + y;
}
}

// Mock
public const string SystemPrompt = """
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
<tool_description>
<tool_name>Sum</tool_name>
<description>
foobarbaz
</description>
<parameters>
<parameter>
<name>x</name>
<type>int</type>
<description>p1</description>
</parameter>
<parameter>
<name>y</name>
<type>int</type>
<description>p2</description>
</parameter>
</parameters>
</tool_description>
</tools>
""";

public static class PromptXml
{
public const string Sum = """
<tool_description>
<tool_name>Sum</tool_name>
<description>
foobarbaz
</description>
<parameters>
<parameter>
<name>x</name>
<type>int</type>
<description>p1</description>
</parameter>
<parameter>
<name>y</name>
<type>int</type>
<description>p2</description>
</parameter>
</parameters>
</tool_description>
""";
}

//public static ValueTask<object> InvokeAsync()
//{
//}
}

51 changes: 39 additions & 12 deletions src/Claudia.FunctionGenerator/Emitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@ namespace Claudia.FunctionGenerator;

public class Emitter
{
const string SystemPromptHead = """
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
""";

private SourceProductionContext context;
private ParseResult[] result;

Expand All @@ -20,31 +35,43 @@ internal void Emit()
{
var type = item.Key!;

var typeDocComment = type.GetDocumentationCommentXml();

var name = type.Name;
var description = ((string)XElement.Parse(typeDocComment).Element("summary")).Trim();

var tools = new List<XElement>();

foreach (var method in item)
{

var docComment = method.MethodSymbol.GetDocumentationCommentXml();
var xml = XElement.Parse(docComment);

var description = ((string)xml.Element("summary")).Trim();

var parameters = new List<XElement>();
foreach (var p in xml.Elements("param"))
{
var paramDescription = ((string)p).Trim();

//new XElement("parameter",
// new XElement(
// type retrieve from method symbol
var name = p.Attribute("name").Value.Trim();
var paramType = method.MethodSymbol.Parameters.First(x => x.Name == name).Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat);

parameters.Add(new XElement("parameter",
new XElement("name", name),
new XElement("type", paramType),
new XElement("description", paramDescription)));
}

var tool = new XElement("tool_description",
new XElement("tool_name", method.MethodSymbol.Name),
new XElement("description", description),
new XElement("parameters", parameters));

tools.Add(tool);
}

var finalXml = new XElement("tools", tools);




new XElement("tool_description",
new XElement("tool_name", name),
new XElement("description", description),
new XElement("parameters", null));

}

Expand Down
62 changes: 55 additions & 7 deletions src/Claudia/Anthropic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async Task<MessagesResponse> IMessages.CreateAsync(MessageRequest request, Reque
request.Stream = null;
using var msg = await SendRequestAsync(request, overrideOptions, cancellationToken).ConfigureAwait(false);

var result = await RequestWithCancelAsync(msg, cancellationToken, overrideOptions, false, static (x, ct) => x.Content.ReadFromJsonAsync<MessagesResponse>(AnthropicJsonSerialzierContext.Default.Options, ct)).ConfigureAwait(false);
var result = await RequestWithAsync(msg, cancellationToken, overrideOptions, static (x, ct) => x.Content.ReadFromJsonAsync<MessagesResponse>(AnthropicJsonSerialzierContext.Default.Options, ct), null).ConfigureAwait(false);
return result!;
}

Expand Down Expand Up @@ -94,7 +94,7 @@ async Task<HttpResponseMessage> SendRequestAsync(MessageRequest request, Request
}

// use ResponseHeadersRead to ignore buffering response.
var msg = await RequestWithCancelAsync((httpClient, (bytes, requestUri, overrideOptions, ApiKey)), cancellationToken, overrideOptions, true, static (x, ct) =>
var msg = await RequestWithAsync((httpClient, (bytes, requestUri, overrideOptions, ApiKey)), cancellationToken, overrideOptions, static (x, ct) =>
{
// for retry, create new HttpRequestMessage per request.
var state = x.Item2;
Expand All @@ -115,16 +115,44 @@ async Task<HttpResponseMessage> SendRequestAsync(MessageRequest request, Request

message.Content = new ByteArrayContent(state.bytes);
return x.httpClient.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, ct);
}, static response =>
{
// Same logic of official sdk's shouldRetry
// https://github.com/anthropics/anthropic-sdk-typescript/blob/104562c3c2164d50da105fed6cfb400b118503d0/src/core.ts#L521
if (response.Headers.TryGetValues("x-should-retry", out var values))
{
foreach (var item in values)
{
if (item == "true") return true;
else if (item == "false") return false;
}
}

var status = (int)response.StatusCode;

// Retry on request timeouts.
if (status == 408) return true;

// Retry on lock timeouts.
if (status == 409) return true;

// Retry on rate limits.
if (status == 429) return true;

// Retry internal errors.
if (status >= 500) return true;

return false;
}).ConfigureAwait(false);

var statusCode = (int)msg.StatusCode;

switch (statusCode)
{
case 200:
return msg;
return msg!;
default:
var shape = await RequestWithCancelAsync(msg, cancellationToken, overrideOptions, false, static (x, ct) => x.Content.ReadFromJsonAsync<ErrorResponseShape>(AnthropicJsonSerialzierContext.Default.Options, ct)).ConfigureAwait(false);
var shape = await RequestWithAsync(msg, cancellationToken, overrideOptions, static (x, ct) => x.Content.ReadFromJsonAsync<ErrorResponseShape>(AnthropicJsonSerialzierContext.Default.Options, ct), null).ConfigureAwait(false);

var error = shape!.ErrorResponse;
var errorMsg = error.Message;
Expand All @@ -137,9 +165,10 @@ async Task<HttpResponseMessage> SendRequestAsync(MessageRequest request, Request
}
}

async Task<TResult> RequestWithCancelAsync<TResult, TState>(TState state, CancellationToken cancellationToken, RequestOptions? overrideOptions, bool doRetry, Func<TState, CancellationToken, Task<TResult>> func)
// with Cancel, Timeout, Retry.
async Task<TResult> RequestWithAsync<TResult, TState>(TState state, CancellationToken cancellationToken, RequestOptions? overrideOptions, Func<TState, CancellationToken, Task<TResult>> func, Func<TResult, bool>? shouldRetry)
{
var retriesRemaining = !doRetry ? 0 : overrideOptions?.MaxRetries ?? MaxRetries;
var retriesRemaining = (shouldRetry == null) ? 0 : (overrideOptions?.MaxRetries ?? MaxRetries);
var timeout = overrideOptions?.Timeout ?? Timeout;
RETRY:
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
Expand All @@ -150,7 +179,26 @@ async Task<TResult> RequestWithCancelAsync<TResult, TState>(TState state, Cancel
{
try
{
return await func(state, cts.Token).ConfigureAwait(false);
var result = await func(state, cts.Token).ConfigureAwait(false);
if (shouldRetry != null)
{
if (shouldRetry(result))
{
if (retriesRemaining > 0)
{
#if NETSTANDARD2_1
var rand = random;
#else
var rand = Random.Shared;
#endif
var sleep = CalculateDefaultRetryTimeoutMillis(rand, retriesRemaining, MaxRetries);
await Task.Delay(TimeSpan.FromMilliseconds(sleep), cancellationToken).ConfigureAwait(false);
retriesRemaining--;
goto RETRY;
}
}
}
return result;
}
catch (OperationCanceledException ex) when (ex.CancellationToken == cts.Token)
{
Expand Down

0 comments on commit 26df93e

Please sign in to comment.