diff --git a/firerequests/main.py b/firerequests/main.py index aaa3ca6..2f9c8ca 100644 --- a/firerequests/main.py +++ b/firerequests/main.py @@ -287,7 +287,7 @@ def compare(self, url: str, filename: Optional[str] = None): except Exception as e: print(f"Error in compare: {e}") - async def call_openai(self, model: str, system_prompt: str, user_prompt: str) -> str: + def call_openai_sync(self, model: str, system_prompt: str, user_prompt: str) -> str: from openai import OpenAI client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) completion = client.chat.completions.create( @@ -299,7 +299,10 @@ async def call_openai(self, model: str, system_prompt: str, user_prompt: str) -> ) return completion.choices[0].message.content - async def call_google(self, model: str, system_prompt: str, user_prompt: str) -> str: + async def call_openai(self, model: str, system_prompt: str, user_prompt: str) -> str: + return await asyncio.to_thread(self.call_openai_sync, model, system_prompt, user_prompt) + + def call_google_sync(self, model: str, system_prompt: str, user_prompt: str) -> str: import google.generativeai as genai genai.configure(api_key=os.environ["GEMINI_API_KEY"]) @@ -321,6 +324,9 @@ async def call_google(self, model: str, system_prompt: str, user_prompt: str) -> response = chat_session.send_message(user_prompt) return response.text + async def call_google(self, model: str, system_prompt: str, user_prompt: str) -> str: + return await asyncio.to_thread(self.call_google_sync, model, system_prompt, user_prompt) + async def generate_batch( self, provider: str, model: str, system_prompt: str, user_prompts: List[str] ) -> List[str]: