From 4797f27ba28d2a64c0120f334f52daf56692b14d Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Wed, 18 Oct 2023 16:56:06 -0700
Subject: [PATCH] Add support for multiple azure deployments

---
 LowCodeLLM/src/openAIWrapper.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/LowCodeLLM/src/openAIWrapper.py b/LowCodeLLM/src/openAIWrapper.py
index 59a3181d..cad07878 100644
--- a/LowCodeLLM/src/openAIWrapper.py
+++ b/LowCodeLLM/src/openAIWrapper.py
@@ -3,7 +3,7 @@
 import os
 import openai
-
+from litellm import Router
 
 class OpenAIWrapper:
     def __init__(self, temperature):
         self.key = os.environ.get("OPENAIKEY")
@@ -20,12 +20,20 @@ def __init__(self, temperature):
             self.use_azure = False
 
         if self.use_azure:
-            openai.api_type = "azure"
-            self.api_base = os.environ.get('API_BASE')
-            openai.api_base = self.api_base
-            self.api_version = os.environ.get('API_VERSION')
-            openai.api_version = self.api_version
-            self.engine = os.environ.get('MODEL')
+
+            model_list = [{ # list of model deployments
+                "model_name": "gpt-3.5-turbo", # openai model name
+                "litellm_params": { # params for litellm completion/embedding call
+                    "model": f"azure/{os.environ.get('MODEL')}",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_version": os.environ.get('API_VERSION'),
+                    "api_base": os.environ.get('API_BASE')
+                },
+                "tpm": 240000,
+                "rpm": 1800
+            }]
+            self.router = Router(model_list = model_list)
+            self.engine = os.environ.get("MODEL")
         else:
             self.chat_model_id = "gpt-3.5-turbo"
 
@@ -40,8 +48,8 @@ def run(self, prompt):
     def _post_request_chat(self, messages):
         try:
             if self.use_azure:
-                response = openai.ChatCompletion.create(
-                    engine=self.engine,
+                response = self.router.completion(
+                    model=self.engine,
                     messages=messages,
                     temperature=self.temperature,
                     max_tokens=self.max_tokens,
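
Note: the hunk above registers a single Azure deployment with litellm's Router. A minimal sketch of how the same model_list could carry several deployments behind one model_name alias is shown below; the deployment names, endpoints, and AZURE_API_KEY_EU / AZURE_API_KEY_US environment variables are illustrative and not part of this patch.

import os
from litellm import Router

# Hypothetical two-deployment configuration: both entries share the
# "gpt-3.5-turbo" alias, so the Router can load-balance across them
# using the declared tpm/rpm limits.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo-eu",          # illustrative deployment name
            "api_key": os.getenv("AZURE_API_KEY_EU"),  # illustrative env var
            "api_version": os.environ.get("API_VERSION"),
            "api_base": "https://example-eu.openai.azure.com",  # illustrative endpoint
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo-us",          # illustrative deployment name
            "api_key": os.getenv("AZURE_API_KEY_US"),  # illustrative env var
            "api_version": os.environ.get("API_VERSION"),
            "api_base": "https://example-us.openai.azure.com",  # illustrative endpoint
        },
        "tpm": 240000,
        "rpm": 1800,
    },
]

router = Router(model_list=model_list)
response = router.completion(
    model="gpt-3.5-turbo",  # resolved against the model_name alias above
    messages=[{"role": "user", "content": "Hello"}],
)

In the patched wrapper, self.engine (the MODEL env var) is what gets passed as the model argument to router.completion, so it needs to match the model_name alias registered in model_list for the Router to resolve a deployment.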