diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..a490b61
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,5 @@
+{
+    "githubPullRequests.ignoredPullRequestBranches": [
+        "main"
+    ]
+}
\ No newline at end of file
diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index 3e0ee24..e525b93 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -31,7 +31,7 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
     if engine_name in __ENGINE_NAME_SHORTCUTS__:
         engine_name = __ENGINE_NAME_SHORTCUTS__[engine_name]
 
-    if "seed" in kwargs and "gpt-4" not in engine_name and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
+    if "seed" in kwargs and "gpt" not in engine_name.lower() and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
         raise ValueError(f"Seed is currently supported only for OpenAI engines, not {engine_name}")
 
     if engine_name.startswith("azure"):
diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py
index 723f04a..e6a863d 100644
--- a/textgrad/engine/openai.py
+++ b/textgrad/engine/openai.py
@@ -59,6 +59,10 @@ def __init__(
                 base_url=base_url,
                 api_key="ollama"
             )
+        elif base_url and "azure" in base_url.lower():
+            # Skip client initialization for Azure
+            # Azure-specific initialization will be handled in AzureChatOpenAI
+            pass
         else:
             raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")
 
@@ -181,20 +185,28 @@ def __init__(
         Raises:
             ValueError: If the AZURE_OPENAI_API_KEY environment variable is not set.
         """
 
+
+
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_azure_{model_string}.db")  # Changed cache path to differentiate from OpenAI cache
-        super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs)
+        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+        if azure_endpoint is None:
+            raise ValueError("Please set the AZURE_OPENAI_ENDPOINT environment variable if you'd like to use Azure OpenAI models.")
+
+        super().__init__(model_string=model_string, system_prompt=system_prompt, base_url=azure_endpoint, **kwargs)
+
+
+
 
-        self.system_prompt = system_prompt
         api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview")
         if os.getenv("AZURE_OPENAI_API_KEY") is None:
-            raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.")
+            raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.")
 
         self.client = AzureOpenAI(
             api_version=api_version,
             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-            azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
+            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
             azure_deployment=model_string,
         )
         self.model_string = model_string
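
For reviewers, a minimal usage sketch of the patched Azure path (not part of the diff): the endpoint, key, and deployment name below are placeholders, and the "azure-" prefix routing and the generate() call rely on the existing get_engine / EngineLM interfaces already present in the repo.

import os
from textgrad.engine import get_engine

# The patched AzureChatOpenAI now reads AZURE_OPENAI_ENDPOINT (instead of the old
# AZURE_OPENAI_API_BASE) and forwards it to the parent ChatOpenAI as base_url,
# where the new "azure" branch skips the default OpenAI client initialization.
os.environ["AZURE_OPENAI_API_KEY"] = "<your-azure-key>"                            # placeholder
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com"   # placeholder
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-07-01-preview"

# "azure-<deployment>" engine names are routed to AzureChatOpenAI by get_engine;
# the deployment name here is an example, substitute your own Azure deployment.
engine = get_engine("azure-gpt-35-turbo")
print(engine.generate("Say hello."))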