Skip to content

Commit

Permalink
errors on Rate Limit headers (#5)
Browse files Browse the repository at this point in the history
* errors for failed header parse, add test
* tests
  • Loading branch information
WillMatthews authored Aug 14, 2024
1 parent f9a4226 commit 7f6b330
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 99 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (h *httpHeader) Header() http.Header {
return http.Header(*h)
}

// GetRateLimitHeaders returns the rate limit information that
// newRateLimitHeaders extracts from this header set.
//
// NOTE(review): this span is a rendered diff, so both sides of the one-line
// signature change appear back to back and the text does not compile as shown.
// The (RateLimitHeaders, error) form is presumably the post-commit one: it
// matches the two-value call `rlHeaders, err := resp.GetRateLimitHeaders()`
// in the updated test. Confirm against the repository.
func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
func (h *httpHeader) GetRateLimitHeaders() (RateLimitHeaders, error) {
return newRateLimitHeaders(h.Header())
}

Expand Down
2 changes: 1 addition & 1 deletion internal/test/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (ts *ServerTest) RegisterHandler(path string, handler handler) {
ts.handlers[path] = handler
}

// AnthropicTestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
// AnthropicTestServer Creates a mocked Anthropic server which can pretend to handle requests during testing.
func (ts *ServerTest) AnthropicTestServer() *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("received request at path %q\n", r.URL.Path)
Expand Down
230 changes: 142 additions & 88 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var rateLimitHeaders = map[string]string{

func TestMessages(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint)
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(rateLimitHeaders))

ts := server.AnthropicTestServer()
ts.Start()
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestMessages(t *testing.T) {

func TestMessagesTokenError(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint)
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(rateLimitHeaders))

ts := server.AnthropicTestServer()
ts.Start()
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestMessagesTokenError(t *testing.T) {

func TestMessagesVision(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint)
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(rateLimitHeaders))

ts := server.AnthropicTestServer()
ts.Start()
Expand Down Expand Up @@ -141,7 +141,7 @@ func TestMessagesVision(t *testing.T) {

func TestMessagesToolUse(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint)
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(rateLimitHeaders))

ts := server.AnthropicTestServer()
ts.Start()
Expand Down Expand Up @@ -231,112 +231,166 @@ func TestMessagesToolUse(t *testing.T) {
}

// TestMessagesRateLimitHeaders exercises GetRateLimitHeaders on a response
// obtained from the mocked /v1/messages endpoint.
//
// NOTE(review): this span is a rendered diff hunk. Lines of the previous flat
// test body (which registers handleMessagesEndpoint directly) are interleaved
// with lines of the replacement, which registers
// handleMessagesEndpoint(<headers>) and is split into two t.Run subtests:
//   - "parses valid rate limit headers"
//   - "returns error for missing rate limit headers"
// Statements therefore appear duplicated or out of order, and the text does
// not compile as shown.
func TestMessagesRateLimitHeaders(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint)

ts := server.AnthropicTestServer()
ts.Start()
defer ts.Close()
t.Run("parses valid rate limit headers", func(t *testing.T) {

baseUrl := ts.URL + "/v1"
client := anthropic.NewClient(
test.GetTestToken(),
anthropic.WithBaseURL(baseUrl),
)
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(rateLimitHeaders))

resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Haiku20240307,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
ts := server.AnthropicTestServer()
ts.Start()
defer ts.Close()

baseUrl := ts.URL + "/v1"
client := anthropic.NewClient(
test.GetTestToken(),
anthropic.WithBaseURL(baseUrl),
)

resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Haiku20240307,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

rlHeaders, err := resp.GetRateLimitHeaders()
if err != nil {
t.Fatalf("GetRateLimitHeaders error: %v", err)
}

bs, err := json.Marshal(rlHeaders)
if err != nil {
t.Fatal(err)
}

// The parsed value and this map are compared as marshaled JSON text below,
// so the check passes only if both encodings emit the keys in the same
// order. NOTE(review): that depends on the field order of RateLimitHeaders,
// which is not fully visible here — verify.
var expectedHeaders = map[string]any{
"anthropic-ratelimit-requests-limit": 100,
"anthropic-ratelimit-requests-remaining": 99,
"anthropic-ratelimit-requests-reset": "2024-06-04T07:13:19Z",
"anthropic-ratelimit-tokens-limit": 10000,
"anthropic-ratelimit-tokens-remaining": 9900,
"anthropic-ratelimit-tokens-reset": "2024-06-04T07:13:19Z",
"retry-after": 100,
}

bs2, err := json.Marshal(expectedHeaders)
if err != nil {
t.Fatal(err)
}

if string(bs) != string(bs2) {
t.Fatalf("rate limit headers mismatch. got %s, want %s", string(bs), string(bs2))
}
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

rateLimitHeaders := resp.GetRateLimitHeaders()
t.Run("returns error for missing rate limit headers", func(t *testing.T) {

bs, err := json.Marshal(rateLimitHeaders)
if err != nil {
t.Fatal(err)
}
// An empty map makes the mock handler set no rate limit headers at all.
invalidHeaders := map[string]string{}

bs2, err := json.Marshal(rateLimitHeaders)
if err != nil {
t.Fatal(err)
}
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handleMessagesEndpoint(invalidHeaders))

if string(bs) != string(bs2) {
t.Fatalf("rate limit headers mismatch. got %s, want %s", string(bs), string(bs2))
}
ts := server.AnthropicTestServer()
ts.Start()
defer ts.Close()

baseUrl := ts.URL + "/v1"
client := anthropic.NewClient(
test.GetTestToken(),
anthropic.WithBaseURL(baseUrl),
)

resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Haiku20240307,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

_, err = resp.GetRateLimitHeaders()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}

// handleMessagesEndpoint is the mock handler for the /v1/messages route used
// by the tests. It decodes the MessagesRequest, builds a canned
// MessagesResponse (a plain "hello" text reply, or a tool-use / tool-result
// reply when the request declares tools), sets response headers, and writes
// the JSON body.
//
// NOTE(review): this span is a rendered diff hunk. The earlier version, a
// plain handler func(w, r) that always sets the package-level
// rateLimitHeaders, is interleaved line by line with the later version, a
// factory that takes the header map to send and returns the handler closure.
// Statements therefore appear twice, and the text does not compile as shown.
func handleMessagesEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// Allows for injection of custom rate limit headers in the response to test client parsing.
func handleMessagesEndpoint(headers map[string]string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// NOTE(review): neither copy of the method check below returns after
// http.Error, so a non-POST request still continues into the request
// decoding that follows.
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

var messagesReq anthropic.MessagesRequest
if messagesReq, err = getMessagesRequest(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
var messagesReq anthropic.MessagesRequest
if messagesReq, err = getMessagesRequest(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}

var hasToolResult bool
var hasToolResult bool

// Scan the request messages for a tool_result content part; it selects
// which canned reply is produced when tools are declared.
for _, m := range messagesReq.Messages {
for _, c := range m.Content {
if c.Type == anthropic.MessagesContentTypeToolResult {
hasToolResult = true
break
for _, m := range messagesReq.Messages {
for _, c := range m.Content {
if c.Type == anthropic.MessagesContentTypeToolResult {
hasToolResult = true
break
}
}
}
}

res := anthropic.MessagesResponse{
Type: "completion",
ID: strconv.Itoa(int(time.Now().Unix())),
Role: anthropic.RoleAssistant,
Content: []anthropic.MessageContent{
anthropic.NewTextMessageContent("hello"),
},
StopReason: anthropic.MessagesStopReasonEndTurn,
Model: messagesReq.Model,
Usage: anthropic.MessagesUsage{
InputTokens: 10,
OutputTokens: 10,
},
}
res := anthropic.MessagesResponse{
Type: "completion",
ID: strconv.Itoa(int(time.Now().Unix())),
Role: anthropic.RoleAssistant,
Content: []anthropic.MessageContent{
anthropic.NewTextMessageContent("hello"),
},
StopReason: anthropic.MessagesStopReasonEndTurn,
Model: messagesReq.Model,
Usage: anthropic.MessagesUsage{
InputTokens: 10,
OutputTokens: 10,
},
}

if len(messagesReq.Tools) > 0 {
if hasToolResult {
res.Content = []anthropic.MessageContent{
anthropic.NewTextMessageContent("The current weather in San Francisco is 65 degrees Fahrenheit. It's a nice, moderate temperature typical of the San Francisco Bay Area climate."),
}
} else {
m := map[string]any{
"location": "San Francisco, CA",
"unit": "celsius",
}
bs, _ := json.Marshal(m)
res.Content = []anthropic.MessageContent{
anthropic.NewTextMessageContent("Okay, let me check the weather in San Francisco:"),
anthropic.NewToolUseMessageContent("toolu_01Ex86JyJAe8RSbFRCTM3pQo", "get_weather", bs),
if len(messagesReq.Tools) > 0 {
if hasToolResult {
res.Content = []anthropic.MessageContent{
anthropic.NewTextMessageContent("The current weather in San Francisco is 65 degrees Fahrenheit. It's a nice, moderate temperature typical of the San Francisco Bay Area climate."),
}
} else {
m := map[string]any{
"location": "San Francisco, CA",
"unit": "celsius",
}
bs, _ := json.Marshal(m)
res.Content = []anthropic.MessageContent{
anthropic.NewTextMessageContent("Okay, let me check the weather in San Francisco:"),
anthropic.NewToolUseMessageContent("toolu_01Ex86JyJAe8RSbFRCTM3pQo", "get_weather", bs),
}
res.StopReason = anthropic.MessagesStopReasonToolUse
}
res.StopReason = anthropic.MessagesStopReasonToolUse
}
}

// Headers are set before the body is written. One copy of the loop ranges
// over the package-level rateLimitHeaders, the other over the injected
// headers parameter.
resBytes, _ = json.Marshal(res)
for k, v := range rateLimitHeaders {
w.Header().Set(k, v)
resBytes, _ = json.Marshal(res)
for k, v := range headers {
w.Header().Set(k, v)
}
_, _ = w.Write(resBytes)
}
_, _ = w.Write(resBytes)
}

func getMessagesRequest(r *http.Request) (req anthropic.MessagesRequest, err error) {
Expand Down
34 changes: 25 additions & 9 deletions ratelimit_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,27 @@ type RateLimitHeaders struct {
RetryAfter int `json:"retry-after"`
}

// newRateLimitHeaders builds a RateLimitHeaders value from the response
// headers in h: integer limit/remaining counters via strconv.Atoi, RFC 3339
// reset timestamps via time.Parse, and a retry-after integer.
//
// NOTE(review): this span is a rendered diff hunk. The earlier version, which
// discards every parse error with `_` and returns only RateLimitHeaders, is
// interleaved with the later version, which returns (RateLimitHeaders, error).
// The "Expand All" line inside the struct literal marks lines the diff view
// collapsed, so the literal is not fully visible and the text does not compile
// as shown.
//
// In the error-returning version:
//   - every parse result is appended to errs (nil entries included), and the
//     first non-nil entry is returned, so only one failure is reported;
//   - the headers value is returned alongside the error, populated with
//     whatever did parse;
//   - h.Get yields "" for an absent header, which Atoi and time.Parse both
//     reject, so any missing header (retry-after included) produces an error;
//   - retry-after is read from the "retry-after" key, where the earlier version
//     read "anthropic-ratelimit-retry-after".
func newRateLimitHeaders(h http.Header) RateLimitHeaders {
requestsLimit, _ := strconv.Atoi(h.Get("anthropic-ratelimit-requests-limit"))
requestsRemaining, _ := strconv.Atoi(h.Get("anthropic-ratelimit-requests-remaining"))
requestsReset, _ := time.Parse(time.RFC3339, h.Get("anthropic-ratelimit-requests-reset"))
func newRateLimitHeaders(h http.Header) (RateLimitHeaders, error) {
errs := []error{}

tokensLimit, _ := strconv.Atoi(h.Get("anthropic-ratelimit-tokens-limit"))
tokensRemaining, _ := strconv.Atoi(h.Get("anthropic-ratelimit-tokens-remaining"))
tokensReset, _ := time.Parse(time.RFC3339, h.Get("anthropic-ratelimit-tokens-reset"))
requestsLimit, err := strconv.Atoi(h.Get("anthropic-ratelimit-requests-limit"))
errs = append(errs, err)
requestsRemaining, err := strconv.Atoi(h.Get("anthropic-ratelimit-requests-remaining"))
errs = append(errs, err)
requestsReset, err := time.Parse(time.RFC3339, h.Get("anthropic-ratelimit-requests-reset"))
errs = append(errs, err)

retryAfter, _ := strconv.Atoi(h.Get("anthropic-ratelimit-retry-after"))
tokensLimit, err := strconv.Atoi(h.Get("anthropic-ratelimit-tokens-limit"))
errs = append(errs, err)
tokensRemaining, err := strconv.Atoi(h.Get("anthropic-ratelimit-tokens-remaining"))
errs = append(errs, err)
tokensReset, err := time.Parse(time.RFC3339, h.Get("anthropic-ratelimit-tokens-reset"))
errs = append(errs, err)

return RateLimitHeaders{
retryAfter, err := strconv.Atoi(h.Get("retry-after"))
errs = append(errs, err)

headers := RateLimitHeaders{
RequestsLimit: requestsLimit,
RequestsRemaining: requestsRemaining,
RequestsReset: requestsReset,
Expand All @@ -43,4 +52,11 @@ func newRateLimitHeaders(h http.Header) RateLimitHeaders {
TokensReset: tokensReset,
RetryAfter: retryAfter,
}

// Report the first parse failure, in the order the headers were read above.
for _, e := range errs {
if e != nil {
return headers, e
}
}
return headers, nil
}

0 comments on commit 7f6b330

Please sign in to comment.