trying to make engine implementation independent from cache #76

Open · wants to merge 2 commits into base: main
70 changes: 70 additions & 0 deletions textgrad/engine/base.py
@@ -1,6 +1,8 @@
import hashlib
import diskcache as dc
from abc import ABC, abstractmethod
from typing import Union, List
import json

class EngineLM(ABC):
    system_prompt: str = "You are a helpful, creative, and smart assistant."
@@ -41,3 +43,71 @@ def __setstate__(self, state):
        # Restore the cache after unpickling
        self.__dict__.update(state)
        self.cache = dc.Cache(self.cache_path)

import platformdirs
import os

class CachedLLM(CachedEngine, EngineLM):
    def __init__(self, model_string, is_multimodal=False, do_cache=False):
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_openai_{model_string}.db")

        super().__init__(cache_path=cache_path)
        self.model_string = model_string
        self.is_multimodal = is_multimodal
        self.do_cache = do_cache

    def __call__(self, prompt, **kwargs):
        return self.generate(prompt, **kwargs)

    @abstractmethod
    def _generate_from_single_prompt(self, prompt: str, system_prompt: str = None, **kwargs):
        pass

    @abstractmethod
    def _generate_from_multiple_input(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
        pass

    def single_prompt_generate(self, prompt: str, system_prompt: str = None, **kwargs):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        if self.do_cache:
            cache_or_none = self._check_cache(sys_prompt_arg + prompt)
            if cache_or_none is not None:
                return cache_or_none

        response = self._generate_from_single_prompt(prompt, system_prompt=sys_prompt_arg, **kwargs)

        if self.do_cache:
            self._save_cache(sys_prompt_arg + prompt, response)
        return response

    def multimodal_generate(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        key = "".join([str(k) for k in content])
        if self.do_cache:
            cache_key = sys_prompt_arg + key
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        response = self._generate_from_multiple_input(content, system_prompt=sys_prompt_arg, **kwargs)

        if self.do_cache:
            cache_key = sys_prompt_arg + key
            self._save_cache(cache_key, response)

        return response

    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
        if isinstance(content, str):
            return self.single_prompt_generate(content, system_prompt=system_prompt, **kwargs)

        elif isinstance(content, list):
            has_multimodal_input = any(isinstance(item, bytes) for item in content)
            if has_multimodal_input and not self.is_multimodal:
raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")

            return self.multimodal_generate(content, system_prompt=system_prompt, **kwargs)
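For reviewers, a minimal sketch (not part of this PR) of how a concrete backend would plug into the proposed CachedLLM base class: only the two abstract hooks are implemented, while prompt dispatch and disk caching are inherited from CachedLLM / CachedEngine. The class name MyBackendEngine and its canned reply are hypothetical, and the import assumes this branch is merged with diskcache and platformdirs installed.

```python
from typing import List, Union

from textgrad.engine.base import CachedLLM  # assumes this PR is merged


class MyBackendEngine(CachedLLM):
    """Hypothetical engine: only backend-specific generation is supplied here."""

    def __init__(self, model_string: str = "my-model", do_cache: bool = True):
        super().__init__(model_string=model_string, is_multimodal=False, do_cache=do_cache)

    def _generate_from_single_prompt(self, prompt: str, system_prompt: str = None, **kwargs):
        # A real engine would call its API client here; a canned reply keeps the sketch runnable.
        return f"[{self.model_string}] reply to: {prompt}"

    def _generate_from_multiple_input(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
        raise NotImplementedError("This sketch is text-only.")


engine = MyBackendEngine()
print(engine("What is 2 + 2?"))  # a second identical call would be served from the disk cache
```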

88 changes: 87 additions & 1 deletion textgrad/engine/openai.py
@@ -14,7 +14,8 @@
)
from typing import List, Union

from .base import EngineLM, CachedEngine

from .base import EngineLM, CachedEngine, CachedLLM
from .engine_utils import get_image_type_from_bytes

# Default base URL for OLLAMA
@@ -158,6 +159,91 @@ def _generate_from_multiple_input(
        self._save_cache(cache_key, response_text)
        return response_text


class OpenAIWithCachedLLM(CachedLLM):
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(self, model_string, is_multimodal=False, system_prompt: str = DEFAULT_SYSTEM_PROMPT, do_cache=False):
        """
        :param model_string: name of the OpenAI model to call
        :param is_multimodal: whether the model accepts image inputs
        :param system_prompt: default system prompt used when none is passed per call
        :param do_cache: whether to cache responses on disk
        """
        super().__init__(model_string=model_string, is_multimodal=is_multimodal, do_cache=do_cache)

        self.system_prompt = system_prompt

        if os.getenv("OPENAI_API_KEY") is None:
            raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY")
        )

    def _generate_from_single_prompt(
        self, prompt: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
    ):

        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response = response.choices[0].message.content
        return response

    def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
        """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
        """
        formatted_content = []
        for item in content:
            if isinstance(item, bytes):
                # For now, bytes are assumed to be images
                image_type = get_image_type_from_bytes(item)
                base64_image = base64.b64encode(item).decode('utf-8')
                formatted_content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{image_type};base64,{base64_image}"
                    }
                })
            elif isinstance(item, str):
                formatted_content.append({
                    "type": "text",
                    "text": item
                })
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return formatted_content

    def _generate_from_multiple_input(
        self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        formatted_content = self._format_content(content)

        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_content},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response_text = response.choices[0].message.content
        return response_text

class AzureChatOpenAI(ChatOpenAI):
    def __init__(
        self,
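As a usage note, a hedged sketch (not part of the diff) of how the new OpenAIWithCachedLLM engine would be called: the model name "gpt-4o-mini" is an assumption, and the import path assumes this branch is merged.

```python
import os

from textgrad.engine.openai import OpenAIWithCachedLLM  # assumes this PR is merged

# The constructor raises ValueError if the key is missing.
assert os.getenv("OPENAI_API_KEY"), "set OPENAI_API_KEY first"

engine = OpenAIWithCachedLLM(model_string="gpt-4o-mini", do_cache=True)

# A plain string is routed to single_prompt_generate and cached on repeat calls.
print(engine.generate("Summarize gradient descent in one sentence."))

# Mixed text + image bytes would require is_multimodal=True at construction time:
# vision_engine = OpenAIWithCachedLLM(model_string="gpt-4o", is_multimodal=True)
# with open("figure.png", "rb") as f:
#     print(vision_engine.generate(["Describe this image.", f.read()]))
```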