Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OPENAI_API_KEY not necessary for Azure OpenAI #137

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from textgrad import Variable, TextualGradientDescent, BlackboxLLM, sum
from textgrad.engine.base import EngineLM
from textgrad.engine.openai import ChatOpenAI
from textgrad.engine.openai import AzureChatOpenAI, ChatOpenAI
from textgrad.autograd import LLMCall, FormattedLLMCall

logging.disable(logging.CRITICAL)
Expand Down Expand Up @@ -247,4 +247,15 @@ def test_multimodal_from_url():
image_variable_2 = Variable(image_data,
role_description="image to answer a question about", requires_grad=False)

assert image_variable_2.value == image_variable.value
assert image_variable_2.value == image_variable.value

def test_azure_openai_engine():
if os.environ.get("OPENAI_API_KEY"):
os.environ.pop("OPENAI_API_KEY")

with pytest.raises(ValueError):
engine = AzureChatOpenAI()

os.environ['AZURE_OPENAI_API_KEY'] = "fake_key"
os.environ['AZURE_OPENAI_API_BASE'] = "fake_base"
engine = AzureChatOpenAI()
40 changes: 23 additions & 17 deletions textgrad/engine/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def __init__(
system_prompt: str=DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool=False,
base_url: str=None,
azure_openai: bool=False,
**kwargs):
"""
:param model_string:
:param system_prompt:
:param base_url: Used to support Ollama
:param azure_openai: Set to True if you use Azure OpenAI.
"""
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
Expand All @@ -47,20 +49,21 @@ def __init__(
self.system_prompt = system_prompt
self.base_url = base_url

if not base_url:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
elif base_url and base_url == OLLAMA_BASE_URL:
self.client = OpenAI(
base_url=base_url,
api_key="ollama"
)
else:
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")
if not azure_openai:
if not base_url:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
elif base_url and base_url == OLLAMA_BASE_URL:
self.client = OpenAI(
base_url=base_url,
api_key="ollama"
)
else:
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")

self.model_string = model_string
self.is_multimodal = is_multimodal
Expand Down Expand Up @@ -184,11 +187,14 @@ def __init__(
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_azure_{model_string}.db") # Changed cache path to differentiate from OpenAI cache

super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs)
super().__init__(cache_path=cache_path,
system_prompt=system_prompt,
azure_openai=True,
**kwargs)

self.system_prompt = system_prompt
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview")
if os.getenv("AZURE_OPENAI_API_KEY") is None:
if (os.getenv("AZURE_OPENAI_API_KEY") is None) or (os.getenv("AZURE_OPENAI_API_BASE") is None):
raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.")

self.client = AzureOpenAI(
Expand All @@ -197,4 +203,4 @@ def __init__(
azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
azure_deployment=model_string,
)
self.model_string = model_string
self.model_string = model_string