Skip to content

Commit

Permalink
add: google ai integration [WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Aug 5, 2024
1 parent aeb84b1 commit b2c7b42
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ chromadb>=0.5.0 # LangChain
pysqlite3-binary==0.5.3 # LangChain
cohere>=5.5.8 # Cohere
groq>=0.9.0 # Groq
google-generativeai==0.7.2 # Google Generative AI
4 changes: 4 additions & 0 deletions weave/autopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def autopatch() -> None:
from .integrations.anthropic.anthropic_sdk import anthropic_patcher
from .integrations.cohere.cohere_sdk import cohere_patcher
from .integrations.dspy.dspy_sdk import dspy_patcher
from .integrations.gemini.gemini_sdk import gemini_patcher
from .integrations.groq.groq_sdk import groq_patcher
from .integrations.langchain.langchain import langchain_patcher
from .integrations.litellm.litellm import litellm_patcher
Expand All @@ -25,12 +26,14 @@ def autopatch() -> None:
groq_patcher.attempt_patch()
dspy_patcher.attempt_patch()
cohere_patcher.attempt_patch()
gemini_patcher.attempt_patch()


def reset_autopatch() -> None:
from .integrations.anthropic.anthropic_sdk import anthropic_patcher
from .integrations.cohere.cohere_sdk import cohere_patcher
from .integrations.dspy.dspy_sdk import dspy_patcher
from .integrations.gemini.gemini_sdk import gemini_patcher
from .integrations.groq.groq_sdk import groq_patcher
from .integrations.langchain.langchain import langchain_patcher
from .integrations.litellm.litellm import litellm_patcher
Expand All @@ -47,3 +50,4 @@ def reset_autopatch() -> None:
groq_patcher.undo_patch()
dspy_patcher.undo_patch()
cohere_patcher.undo_patch()
gemini_patcher.undo_patch()
Empty file.
58 changes: 58 additions & 0 deletions weave/integrations/gemini/gemini_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import importlib
from typing import Callable, Dict, List, Optional, Union

from google.ai.generativelanguage_v1beta.types.safety import SafetyRating
from google.generativeai.types.generation_types import GenerateContentResponse

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher


def gemini_accumulator(
acc: GenerateContentResponse, value: GenerateContentResponse
) -> GenerateContentResponse:
for candidate_idx in range(len(value.candidates)):
candidate = value.candidates[candidate_idx]
for part_idx in range(len(candidate.content.parts)):
acc.candidates[candidate_idx].content.parts[
part_idx
].text += candidate.content.parts[part_idx].text
if isinstance(acc.candidates[candidate_idx].safety_ratings[0], SafetyRating):
acc.candidates[candidate_idx].safety_ratings = [
value.candidates[candidate_idx].safety_ratings
]
else:
acc.candidates[candidate_idx].safety_ratings.append(
value.candidates[candidate_idx].safety_ratings
)
return acc


def should_use_accumulator(inputs: Dict) -> bool:
return isinstance(inputs, dict) and bool(inputs.get("stream"))


def gemini_wrapper(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
# return op
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: gemini_accumulator,
should_accumulate=should_use_accumulator,
)

return wrapper


gemini_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("google.generativeai"),
"GenerativeModel.generate_content",
gemini_wrapper(name="google.generativeai.GenerativeModel.generate_content"),
),
]
)

0 comments on commit b2c7b42

Please sign in to comment.