Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rishiraj authored Nov 6, 2024
1 parent dc8258c commit a8a652c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions firerequests/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def compare(self, url: str, filename: Optional[str] = None):
except Exception as e:
print(f"Error in compare: {e}")

async def call_openai(self, model: str, system_prompt: str, user_prompt: str) -> str:
def call_openai_sync(self, model: str, system_prompt: str, user_prompt: str) -> str:
from openai import OpenAI
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
completion = client.chat.completions.create(
Expand All @@ -299,7 +299,10 @@ async def call_openai(self, model: str, system_prompt: str, user_prompt: str) ->
)
return completion.choices[0].message.content

async def call_google(self, model: str, system_prompt: str, user_prompt: str) -> str:
async def call_openai(self, model: str, system_prompt: str, user_prompt: str) -> str:
return await asyncio.to_thread(self.call_openai_sync, model, system_prompt, user_prompt)

def call_google_sync(self, model: str, system_prompt: str, user_prompt: str) -> str:
import google.generativeai as genai
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

Expand All @@ -321,6 +324,9 @@ async def call_google(self, model: str, system_prompt: str, user_prompt: str) ->
response = chat_session.send_message(user_prompt)
return response.text

async def call_google(self, model: str, system_prompt: str, user_prompt: str) -> str:
return await asyncio.to_thread(self.call_google_sync, model, system_prompt, user_prompt)

async def generate_batch(
self, provider: str, model: str, system_prompt: str, user_prompts: List[str]
) -> List[str]:
Expand Down

0 comments on commit a8a652c

Please sign in to comment.