feat(weave): Implement VertexAI integration (#2743)
* add: vertexai autopatch integration

* fix: lint

* add: vertexai_on_finish for handling token count, pricing, and execution time

* add: tests for vertexai integration

* fix: lint

* add: patch for ImageGenerationModel.generate_images

* update: tests

* update: cassettes

* fix: lint

* add: cassettes for async cases

* update: tests

* update: docs

* fix: lint

* add: dictify support

* update: tests

* add: skips to vertexai tests
soumik12345 authored Nov 26, 2024
1 parent 33a90eb commit b89825e
Showing 8 changed files with 276 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -240,6 +240,7 @@ jobs:
          'mistral1',
          'notdiamond',
          'openai',
          'vertexai',
          'scorers_tests',
          'pandas-test',
        ]
25 changes: 16 additions & 9 deletions docs/docs/guides/integrations/google-gemini.md
@@ -16,13 +16,28 @@ import os
import google.generativeai as genai
import weave

- weave.init(project_name="google_ai_studio-test")
+ weave.init(project_name="google-ai-studio-test")

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Write a story about an AI and magic")
```

Weave will also automatically capture traces for [Vertex APIs](https://cloud.google.com/vertexai/docs). To start tracking, call `weave.init(project_name="<YOUR-WANDB-PROJECT-NAME>")` and use the library as normal.

```python
import vertexai
import weave
from vertexai.generative_models import GenerativeModel

weave.init(project_name="vertex-ai-test")
vertexai.init(project="<YOUR-VERTEXAI-PROJECT-NAME>", location="<YOUR-VERTEXAI-PROJECT-LOCATION>")
model = GenerativeModel("gemini-1.5-flash-002")
response = model.generate_content(
    "What's a good name for a flower shop specialising in selling dried flower bouquets?"
)
```

## Track your own ops

Wrapping a function with `@weave.op` starts capturing inputs, outputs and app logic so you can debug how data flows through your app. You can deeply nest ops and build a tree of functions that you want to track. This also starts automatically versioning code as you experiment to capture ad-hoc details that haven't been committed to git.
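For example, here is a minimal sketch building on the snippet above; the `generate_story` op and its prompt are illustrative, not part of the SDK:

```python
import os

import google.generativeai as genai
import weave

weave.init(project_name="google-ai-studio-test")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

@weave.op()
def generate_story(prompt: str) -> str:
    # Everything inside the op, including the generate_content call below,
    # is captured as a nested trace.
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(prompt)
    return response.text

generate_story("Write a story about an AI and magic")
```

Each call to `generate_story` then appears in Weave with its inputs, its output, and the nested Gemini call.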
@@ -97,11 +112,3 @@ Given a weave reference to any `weave.Model` object, you can spin up a FastAPI server and serve it:
```shell
weave serve weave:///your_entity/project-name/YourModel:<hash>
```

## Vertex API

Full Weave support for the `Vertex AI SDK` Python package is currently in development; however, there is a way you can integrate Weave with the Vertex API.

The Vertex API supports OpenAI SDK compatibility ([docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library)); if this is how you build your application, Weave will automatically track your LLM calls via our [OpenAI](/guides/integrations/openai) SDK integration.

\* Please note that some features may not work fully, as the Vertex API doesn't implement the complete set of OpenAI SDK capabilities.
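For illustration, a hedged sketch of that route: the endpoint URL shape, the `gcloud` token retrieval, and the `google/gemini-1.5-flash-002` model id follow Google's OpenAI-compatibility docs as of this writing, and the project name is a placeholder, so verify the current format before relying on it:

```python
import subprocess

import weave
from openai import OpenAI

weave.init(project_name="vertex-openai-compat-test")  # illustrative project name

project, location = "<YOUR-PROJECT-ID>", "us-central1"
# Short-lived OAuth token; assumes the gcloud CLI is installed and authenticated.
token = subprocess.check_output(
    ["gcloud", "auth", "print-access-token"], text=True
).strip()

client = OpenAI(
    base_url=(
        f"https://{location}-aiplatform.googleapis.com/v1beta1/"
        f"projects/{project}/locations/{location}/endpoints/openapi"
    ),
    api_key=token,
)
# Tracked by Weave's OpenAI integration like any other chat completion.
response = client.chat.completions.create(
    model="google/gemini-1.5-flash-002",
    messages=[{"role": "user", "content": "Hello from Vertex via the OpenAI SDK!"}],
)
print(response.choices[0].message.content)
```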
1 change: 1 addition & 0 deletions noxfile.py
@@ -47,6 +47,7 @@ def lint(session):
        "mistral1",
        "notdiamond",
        "openai",
        "vertexai",
        "scorers_tests",
        "pandas-test",
    ],
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,6 +80,7 @@ notdiamond = ["notdiamond>=0.3.21", "litellm<=1.49.1"]
openai = ["openai>=1.0.0"]
pandas-test = ["pandas>=2.2.3"]
modal = ["modal", "python-dotenv"]
vertexai = ["vertexai>=1.70.0"]
test = [
"nox",
"pytest>=8.2.0",
125 changes: 125 additions & 0 deletions tests/integrations/vertexai/vertexai_test.py
@@ -0,0 +1,125 @@
import pytest

from weave.integrations.integration_utilities import op_name_from_ref


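# Note: these tests hit the live Vertex AI service, so they are skipped by
# default; the flaky marker reruns them to tolerate non-deterministic output.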
@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    model.generate_content("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    response = model.generate_content("What is the capital of France?", stream=True)
    chunks = [chunk.text for chunk in response]
    assert len(chunks) > 1

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    await model.generate_content_async("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")

    async def get_response():
        chunks = []
        async for chunk in await model.generate_content_async(
            "What is the capital of France?", stream=True
        ):
            if chunk.text:
                chunks.append(chunk.text)
        return chunks

    await get_response()

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
0 changes: 0 additions & 0 deletions weave/integrations/vertexai/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions weave/integrations/vertexai/vertexai_sdk.py
@@ -0,0 +1,128 @@
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call

if TYPE_CHECKING:
    from vertexai.generative_models import GenerationResponse


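# Folds streamed GenerationResponse chunks into one response object so the
# recorded trace output matches what a non-streaming call would return.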
def vertexai_accumulator(
    acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
    from google.cloud.aiplatform_v1beta1.types import (
        prediction_service as gapic_prediction_service_types,
    )
    from vertexai.generative_models import GenerationResponse

    if acc is None:
        return value

    candidates = []
    for i, value_candidate in enumerate(value.candidates):
        accumulated_texts = []
        for j, value_part in enumerate(value_candidate.content.parts):
            accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
            accumulated_texts.append(accumulated_text)
        parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
        content = gapic_content_types.Content(
            role=value_candidate.content.role, parts=parts
        )
        candidate = gapic_content_types.Candidate(content=content)
        candidates.append(candidate)
    accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
        candidates=candidates
    )
    acc = GenerationResponse._from_gapic(accumulated_response)

    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
    acc.usage_metadata.candidates_token_count += (
        value.usage_metadata.candidates_token_count
    )
    acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
    return acc


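# Post-call hook: records the request count and token usage on the call
# summary, keyed by the short model name (e.g. "gemini-1.5-flash").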
def vertexai_on_finish(
    call: Call, output: Any, exception: Optional[BaseException]
) -> None:
    original_model_name = call.inputs["self"]._model_name
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        usage[model_name].update(
            {
                "prompt_tokens": output.usage_metadata.prompt_token_count,
                "completion_tokens": output.usage_metadata.candidates_token_count,
                "total_tokens": output.usage_metadata.total_token_count,
            }
        )
    if call.summary is not None:
        call.summary.update(summary_update)


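# Wraps a sync SDK method in a weave.op, attaching the finish handler and an
# accumulator that only engages when the call is made with stream=True.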
def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        op = weave.op()(fn)
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        def _fn_wrapper(fn: Callable) -> Callable:
            @wraps(fn)
            async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await fn(*args, **kwargs)

            return _async_wrapper

        # We need to re-wrap the coroutine function so we can check whether
        # `stream` is used before the call is awaited.
        op = weave.op()(_fn_wrapper(fn))
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


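# Registers patches for the sync, async, and image-generation entry points of
# the Vertex AI SDK; applied by weave's autopatching on weave.init().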
vertexai_patcher = MultiPatcher(
    [
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content",
            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content_async",
            vertexai_wrapper_async(
                name="vertexai.GenerativeModel.generate_content_async"
            ),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.preview.vision_models"),
            "ImageGenerationModel.generate_images",
            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
        ),
    ]
)
4 changes: 4 additions & 0 deletions weave/trace/autopatch.py
@@ -21,6 +21,7 @@ def autopatch() -> None:
    from weave.integrations.mistral import mistral_patcher
    from weave.integrations.notdiamond.tracing import notdiamond_patcher
    from weave.integrations.openai.openai_sdk import openai_patcher
    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

    openai_patcher.attempt_patch()
    mistral_patcher.attempt_patch()
@@ -35,6 +36,7 @@ def autopatch() -> None:
    cohere_patcher.attempt_patch()
    google_genai_patcher.attempt_patch()
    notdiamond_patcher.attempt_patch()
    vertexai_patcher.attempt_patch()


def reset_autopatch() -> None:
@@ -53,6 +55,7 @@ def reset_autopatch() -> None:
    from weave.integrations.mistral import mistral_patcher
    from weave.integrations.notdiamond.tracing import notdiamond_patcher
    from weave.integrations.openai.openai_sdk import openai_patcher
    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

    openai_patcher.undo_patch()
    mistral_patcher.undo_patch()
@@ -67,3 +70,4 @@ def reset_autopatch() -> None:
    cohere_patcher.undo_patch()
    google_genai_patcher.undo_patch()
    notdiamond_patcher.undo_patch()
    vertexai_patcher.undo_patch()
