Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AzureAIInference client , wrapper and test #3301

Closed
wants to merge 10 commits into from
228 changes: 228 additions & 0 deletions autogen/oai/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""Create a Github LLM Client with Azure Fallback.

# Usage example:
if __name__ == "__main__":
config = {
"model": "gpt-4o",
"system_prompt": "You are a knowledgeable history teacher.",
"use_azure_fallback": True
}

wrapper = GithubWrapper(config_list=[config])

response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}])
print(wrapper.message_retrieval(response)[0])

conversation = [
{"role": "user", "content": "Tell me about the French Revolution."},
{"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."},
{"role": "user", "content": "What were the main causes?"}
]

response = wrapper.create(messages=conversation)
print(wrapper.message_retrieval(response)[0])
"""

from __future__ import annotations

import json
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage

from autogen.cache import Cache
from autogen.oai.client_utils import should_hide_tools, validate_parameter

logger = logging.getLogger(__name__)


class GithubClient:
"""GitHub LLM Client with Azure Fallback"""

SUPPORTED_MODELS = [
"AI21-Jamba-Instruct",
"cohere-command-r",
"cohere-command-r-plus",
"meta-llama-3-70b-instruct",
"meta-llama-3-8b-instruct",
"meta-llama-3.1-405b-instruct",
"meta-llama-3.1-70b-instruct",
"meta-llama-3.1-8b-instruct",
"mistral-large",
"mistral-large-2407",
"mistral-nemo",
"mistral-small",
"gpt-4o",
"gpt-4o-mini",
"phi-3-medium-instruct-128k",
"phi-3-medium-instruct-4k",
"phi-3-mini-instruct-128k",
"phi-3-mini-instruct-4k",
"phi-3-small-instruct-128k",
"phi-3-small-instruct-8k",
]

def __init__(self, **kwargs):
self.github_endpoint_url = "https://models.inference.ai.azure.com/chat/completions"
self.model = kwargs.get("model")
self.system_prompt = kwargs.get("system_prompt", "You are a helpful assistant.")
self.use_azure_fallback = kwargs.get("use_azure_fallback", True)
self.rate_limit_reset_time = 0
self.request_count = 0
self.max_requests_per_minute = 15
self.max_requests_per_day = 150

self.github_token = os.environ.get("GITHUB_TOKEN")
self.azure_api_key = os.environ.get("AZURE_API_KEY")

if not self.github_token:
raise ValueError("GITHUB_TOKEN environment variable is not set.")
if self.use_azure_fallback and not self.azure_api_key:
raise ValueError("AZURE_API_KEY environment variable is not set.")

if self.model.lower() not in [model.lower() for model in self.SUPPORTED_MODELS]:
raise ValueError(f"Model {self.model} is not supported. Please choose from {self.SUPPORTED_MODELS}")

def message_retrieval(self, response: ChatCompletion) -> List[str]:
"""Retrieve the messages from the response."""
return [choice.message.content for choice in response.choices]

def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config."""
messages = params.get("messages", [])

if "system" not in [m["role"] for m in messages]:
messages.insert(0, {"role": "system", "content": self.system_prompt})

data = {"messages": messages, "model": self.model, **params}

if self._check_rate_limit():
try:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"}

response = self._call_api(self.github_endpoint_url, headers, data)
self._increment_request_count()
return self._process_response(response)
except (requests.exceptions.RequestException, ValueError) as e:
logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.")

if self.use_azure_fallback:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"}

response = self._call_api(self.github_endpoint_url, headers, data)
return self._process_response(response)
else:
raise ValueError("Rate limit reached and Azure fallback is disabled.")

def _check_rate_limit(self) -> bool:
"""Check if the rate limit has been reached."""
current_time = time.time()
if current_time < self.rate_limit_reset_time:
return False
if self.request_count >= self.max_requests_per_minute:
self.rate_limit_reset_time = current_time + 60
self.request_count = 0
return False
return True

def _increment_request_count(self):
"""Increment the request count."""
self.request_count += 1

def _call_api(self, endpoint_url: str, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]:
"""Make an API call to either GitHub or Azure."""
response = requests.post(endpoint_url, headers=headers, json=data)
response.raise_for_status()
return response.json()

def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
"""Process the API response and return a ChatCompletion object."""
choices = [
Choice(
index=i,
message=ChatCompletionMessage(role="assistant", content=choice["message"]["content"]),
finish_reason=choice.get("finish_reason"),
)
for i, choice in enumerate(response_data["choices"])
]

usage = CompletionUsage(
prompt_tokens=response_data["usage"]["prompt_tokens"],
completion_tokens=response_data["usage"]["completion_tokens"],
total_tokens=response_data["usage"]["total_tokens"],
)

return ChatCompletion(
id=response_data["id"],
model=response_data["model"],
created=response_data["created"],
object="chat.completion",
choices=choices,
usage=usage,
)

def cost(self, response: ChatCompletion) -> float:
"""Calculate the cost of the response."""
# Pass
return 0.0 # Placeholder

@staticmethod
def get_usage(response: ChatCompletion) -> Dict:
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost if hasattr(response, "cost") else 0,
"model": response.model,
}


class GithubWrapper:
"""Wrapper for GitHub LLM Client"""

def __init__(self, config_list: Optional[List[Dict[str, Any]]] = None, **kwargs):
self._clients = []
self._config_list = []

if config_list:
for config in config_list:
self._register_client(config)
self._config_list.append(config)
else:
self._register_client(kwargs)
self._config_list = [kwargs]

def _register_client(self, config: Dict[str, Any]):
client = GithubClient(**config)
self._clients.append(client)

def create(self, **params: Any) -> ChatCompletion:
"""Create a completion using available clients."""
for i, client in enumerate(self._clients):
try:
response = client.create(params)
response.config_id = i
return response
except Exception as e:
logger.warning(f"Error with client {i}: {str(e)}")
if i == len(self._clients) - 1:
raise

def message_retrieval(self, response: ChatCompletion) -> List[str]:
"""Retrieve messages from the response."""
return self._clients[response.config_id].message_retrieval(response)

def cost(self, response: ChatCompletion) -> float:
"""Calculate the cost of the response."""
return self._clients[response.config_id].cost(response)

@staticmethod
def get_usage(response: ChatCompletion) -> Dict:
"""Get usage information from the response."""
return GithubClient.get_usage(response)
131 changes: 131 additions & 0 deletions test/oai/test_githubllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from unittest.mock import MagicMock, patch

import pytest

from autogen.oai.github import GithubClient, GithubWrapper


@pytest.fixture
def github_client():
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
return GithubClient(model="gpt-4o", system_prompt="Test prompt")


@pytest.fixture
def github_wrapper():
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
config = {"model": "gpt-4o", "system_prompt": "Test prompt", "use_azure_fallback": True}
return GithubWrapper(config_list=[config])


def test_github_client_initialization(github_client):
assert github_client.model == "gpt-4o"
assert github_client.system_prompt == "Test prompt"
assert github_client.use_azure_fallback


def test_github_client_unsupported_model():
with pytest.raises(ValueError):
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
GithubClient(model="unsupported-model")


@patch("requests.post")
def test_github_client_create(mock_post, github_client):
mock_response = MagicMock()
mock_response.json.return_value = {
"id": "test_id",
"model": "gpt-4o",
"created": 1234567890,
"choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
mock_post.return_value = mock_response

params = {"messages": [{"role": "user", "content": "Test message"}]}
response = github_client.create(params)

assert response.id == "test_id"
assert response.model == "gpt-4o"
assert len(response.choices) == 1
assert response.choices[0].message.content == "Test response"


def test_github_client_message_retrieval(github_client):
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Response 1")),
MagicMock(message=MagicMock(content="Response 2")),
]

messages = github_client.message_retrieval(mock_response)
assert messages == ["Response 1", "Response 2"]


def test_github_client_cost(github_client):
mock_response = MagicMock()
cost = github_client.cost(mock_response)
assert cost == 0.0 # Assuming the placeholder implementation


def test_github_client_get_usage(github_client):
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_response.usage.total_tokens = 30
mock_response.model = "gpt-4o"

usage = github_client.get_usage(mock_response)
assert usage["prompt_tokens"] == 10
assert usage["completion_tokens"] == 20
assert usage["total_tokens"] == 30
assert usage["model"] == "gpt-4o"


@patch("autogen.oai.github.GithubClient.create")
def test_github_wrapper_create(mock_create, github_wrapper):
mock_response = MagicMock()
mock_create.return_value = mock_response

params = {"messages": [{"role": "user", "content": "Test message"}]}
response = github_wrapper.create(**params)

assert response == mock_response
assert hasattr(response, "config_id")
mock_create.assert_called_once_with(params)


def test_github_wrapper_message_retrieval(github_wrapper):
mock_response = MagicMock()
mock_response.config_id = 0

with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval:
mock_retrieval.return_value = ["Test message"]
messages = github_wrapper.message_retrieval(mock_response)

assert messages == ["Test message"]


def test_github_wrapper_cost(github_wrapper):
mock_response = MagicMock()
mock_response.config_id = 0

with patch.object(github_wrapper._clients[0], "cost") as mock_cost:
mock_cost.return_value = 0.05
cost = github_wrapper.cost(mock_response)

assert cost == 0.05


def test_github_wrapper_get_usage(github_wrapper):
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_response.usage.total_tokens = 30
mock_response.model = "gpt-4o"

usage = github_wrapper.get_usage(mock_response)
assert usage["prompt_tokens"] == 10
assert usage["completion_tokens"] == 20
assert usage["total_tokens"] == 30
assert usage["model"] == "gpt-4o"
Loading