feat(weave): Implement VertexAI integration #2743

Merged: 35 commits merged into master from feat/vertexai on Nov 26, 2024

Changes from all commits (35 commits)
042293c
add: vertexai autopatch integration
soumik12345 Oct 21, 2024
8a228bb
Merge branch 'master' into feat/vertexai
soumik12345 Oct 21, 2024
42a257b
fix: lint
soumik12345 Oct 21, 2024
57e414c
add: vertexai_on_finish for handing token count, pricing, and executi…
soumik12345 Oct 21, 2024
fc651d3
add: tests for vertexai integration
soumik12345 Oct 21, 2024
46e438f
fix: lint
soumik12345 Oct 21, 2024
5cd3a90
Merge branch 'master' into feat/vertexai
soumik12345 Oct 21, 2024
a56f8d8
Merge branch 'master' into feat/vertexai
soumik12345 Oct 28, 2024
6bb2d76
add: patch for ImageGenerationModel.generate_images
soumik12345 Oct 28, 2024
e4414cf
Merge branch 'master' into feat/vertexai
soumik12345 Oct 28, 2024
33d0e7d
Merge branch 'master' into feat/vertexai
soumik12345 Oct 28, 2024
231ade1
update: tests
soumik12345 Oct 28, 2024
a40d4e1
update: cassettes
soumik12345 Oct 28, 2024
3f62601
fix: lint
soumik12345 Oct 28, 2024
9f8238b
Merge branch 'master' into feat/vertexai
soumik12345 Oct 29, 2024
21dc1d7
Merge branch 'master' into feat/vertexai
soumik12345 Oct 30, 2024
fe9d10a
add: cassettes for async cases
soumik12345 Oct 30, 2024
aad4297
Merge branch 'master' into feat/vertexai
soumik12345 Nov 4, 2024
fb81f43
Merge branch 'master' into feat/vertexai
soumik12345 Nov 4, 2024
c6a1274
update: tests
soumik12345 Nov 4, 2024
7ae4ff3
update: docs
soumik12345 Nov 4, 2024
65c4cb3
fix: lint
soumik12345 Nov 4, 2024
5f9c74a
Merge branch 'master' into feat/vertexai
soumik12345 Nov 5, 2024
3356f64
Merge branch 'master' into feat/vertexai
soumik12345 Nov 6, 2024
6b56bf8
Merge branch 'master' into feat/vertexai
soumik12345 Nov 7, 2024
90795f0
Merge branch 'master' into feat/vertexai
soumik12345 Nov 7, 2024
e137602
add: dictify support
soumik12345 Nov 7, 2024
a33ab2c
Merge branch 'master' into feat/vertexai
soumik12345 Nov 7, 2024
e851329
Merge branch 'master' into feat/vertexai
soumik12345 Nov 7, 2024
00b1657
update: tests
soumik12345 Nov 7, 2024
22b360f
Merge branch 'master' into feat/vertexai
soumik12345 Nov 7, 2024
b1763d6
Merge branch 'master' into feat/vertexai
soumik12345 Nov 22, 2024
f256ec4
add: skips to vertexai tests
soumik12345 Nov 22, 2024
a0b3a77
Merge branch 'master' into feat/vertexai
soumik12345 Nov 25, 2024
8a56bde
Merge branch 'master' into feat/vertexai
soumik12345 Nov 25, 2024
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -240,6 +240,7 @@ jobs:
'mistral1',
'notdiamond',
'openai',
'vertexai',
'scorers_tests',
'pandas-test',
]
25 changes: 16 additions & 9 deletions docs/docs/guides/integrations/google-gemini.md
@@ -16,13 +16,28 @@ import os
import google.generativeai as genai
import weave

weave.init(project_name="google_ai_studio-test")
weave.init(project_name="google-ai-studio-test")

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Write a story about an AI and magic")
```

Weave will also automatically capture traces for [Vertex APIs](https://cloud.google.com/vertex-ai/docs). To start tracking, call `weave.init(project_name="<YOUR-WANDB-PROJECT-NAME>")` and use the library as normal.

```python
import vertexai
import weave
from vertexai.generative_models import GenerativeModel

weave.init(project_name="vertex-ai-test")
vertexai.init(project="<YOUR-VERTEXAIPROJECT-NAME>", location="<YOUR-VERTEXAI-PROJECT-LOCATION>")
model = GenerativeModel("gemini-1.5-flash-002")
response = model.generate_content(
"What's a good name for a flower shop specialising in selling dried flower bouquets?"
)
```
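
Streaming calls are traced as well; this PR accumulates the streamed chunks into a single traced response. A short sketch, reusing the placeholder project settings above:

```python
# Streamed generation: each chunk is merged into one traced call by Weave.
response = model.generate_content(
    "What's a good name for a flower shop specialising in selling dried flower bouquets?",
    stream=True,
)
for chunk in response:
    print(chunk.text, end="")
```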

## Track your own ops

Wrapping a function with `@weave.op` starts capturing inputs, outputs and app logic so you can debug how data flows through your app. You can deeply nest ops and build a tree of functions that you want to track. This also starts automatically versioning code as you experiment to capture ad-hoc details that haven't been committed to git.
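
For instance, a minimal custom op wrapping a Vertex AI call could look like the sketch below (the op name, prompt, and model choice are illustrative, not part of this PR):

```python
import vertexai
import weave
from vertexai.generative_models import GenerativeModel

weave.init(project_name="vertex-ai-test")
vertexai.init(project="<YOUR-VERTEXAI-PROJECT-NAME>", location="us-central1")


@weave.op
def generate_shop_names(product: str) -> str:
    # The patched generate_content call is traced as a child of this op.
    model = GenerativeModel("gemini-1.5-flash-002")
    response = model.generate_content(f"Suggest three names for a shop selling {product}.")
    return response.text


generate_shop_names("dried flower bouquets")
```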
@@ -97,11 +112,3 @@ Given a weave reference to any `weave.Model` object, you can spin up a fastapi server
```shell
weave serve weave:///your_entity/project-name/YourModel:<hash>
```

## Vertex API

Full Weave support for the `Vertex AI SDK` Python package is currently in development; however, there is a way you can integrate Weave with the Vertex API.

The Vertex API supports OpenAI SDK compatibility ([docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library)), and if this is how you build your application, Weave will automatically track your LLM calls via our [OpenAI](/guides/integrations/openai) SDK integration.

\* Please note that some features may not fully work, as the Vertex API doesn't implement the full set of OpenAI SDK capabilities.
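
As a rough sketch of that OpenAI-compatible route (the endpoint URL pattern and token handling follow the linked Google docs; the project and location values are placeholders, so treat this as an assumption rather than a tested recipe):

```python
import openai
import weave
from google.auth import default
from google.auth.transport.requests import Request

weave.init(project_name="vertex-ai-openai-compat")

# Fetch a short-lived access token to authenticate against Vertex AI.
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(Request())

project, location = "<YOUR-VERTEXAI-PROJECT-NAME>", "us-central1"
client = openai.OpenAI(
    base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}/endpoints/openapi",
    api_key=credentials.token,
)
# This call goes through the OpenAI SDK, so Weave's OpenAI integration tracks it.
response = client.chat.completions.create(
    model="google/gemini-1.5-flash-002",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
```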
1 change: 1 addition & 0 deletions noxfile.py
@@ -47,6 +47,7 @@ def lint(session):
"mistral1",
"notdiamond",
"openai",
"vertexai",
"scorers_tests",
"pandas-test",
],
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,6 +80,7 @@ notdiamond = ["notdiamond>=0.3.21", "litellm<=1.49.1"]
openai = ["openai>=1.0.0"]
pandas-test = ["pandas>=2.2.3"]
modal = ["modal", "python-dotenv"]
vertexai = ["vertexai>=1.70.0"]
test = [
"nox",
"pytest>=8.2.0",
125 changes: 125 additions & 0 deletions tests/integrations/vertexai/vertexai_test.py
@@ -0,0 +1,125 @@
import pytest

from weave.integrations.integration_utilities import op_name_from_ref


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    model.generate_content("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    response = model.generate_content("What is the capital of France?", stream=True)
    chunks = [chunk.text for chunk in response]
    assert len(chunks) > 1

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    await model.generate_content_async("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")

    async def get_response():
        chunks = []
        async for chunk in await model.generate_content_async(
            "What is the capital of France?", stream=True
        ):
            if chunk.text:
                chunks.append(chunk.text)
        return chunks

    await get_response()

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
Empty file.
128 changes: 128 additions & 0 deletions weave/integrations/vertexai/vertexai_sdk.py
@@ -0,0 +1,128 @@
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call

if TYPE_CHECKING:
    from vertexai.generative_models import GenerationResponse


def vertexai_accumulator(
    acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
    from google.cloud.aiplatform_v1beta1.types import (
        prediction_service as gapic_prediction_service_types,
    )
    from vertexai.generative_models import GenerationResponse

    if acc is None:
        return value

    candidates = []
    for i, value_candidate in enumerate(value.candidates):
        accumulated_texts = []
        for j, value_part in enumerate(value_candidate.content.parts):
            accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
            accumulated_texts.append(accumulated_text)
        parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
        content = gapic_content_types.Content(
            role=value_candidate.content.role, parts=parts
        )
        candidate = gapic_content_types.Candidate(content=content)
        candidates.append(candidate)
    accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
        candidates=candidates
    )
    acc = GenerationResponse._from_gapic(accumulated_response)

    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
    acc.usage_metadata.candidates_token_count += (
        value.usage_metadata.candidates_token_count
    )
    acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
    return acc


def vertexai_on_finish(
    call: Call, output: Any, exception: Optional[BaseException]
) -> None:
    original_model_name = call.inputs["self"]._model_name
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        usage[model_name].update(
            {
                "prompt_tokens": output.usage_metadata.prompt_token_count,
                "completion_tokens": output.usage_metadata.candidates_token_count,
                "total_tokens": output.usage_metadata.total_token_count,
            }
        )
    if call.summary is not None:
        call.summary.update(summary_update)


def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        op = weave.op()(fn)
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        def _fn_wrapper(fn: Callable) -> Callable:
            @wraps(fn)
            async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await fn(*args, **kwargs)

            return _async_wrapper

        # We need to do this so we can check if `stream` is used.
        op = weave.op()(_fn_wrapper(fn))
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


vertexai_patcher = MultiPatcher(
    [
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content",
            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content_async",
            vertexai_wrapper_async(
                name="vertexai.GenerativeModel.generate_content_async"
            ),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.preview.vision_models"),
            "ImageGenerationModel.generate_images",
            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
        ),
    ]
)
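
A small, illustrative check of the accumulator's merging behavior (hypothetical snippet, not part of this PR; it assumes the `vertexai` and `google-cloud-aiplatform` packages are installed and that `GenerationResponse._from_gapic` behaves as used above):

```python
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
    prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import GenerationResponse

from weave.integrations.vertexai.vertexai_sdk import vertexai_accumulator


def make_chunk(text: str) -> GenerationResponse:
    # Build a minimal streamed chunk with a single text part.
    raw = gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                content=gapic_content_types.Content(
                    role="model", parts=[gapic_content_types.Part(text=text)]
                )
            )
        ]
    )
    return GenerationResponse._from_gapic(raw)


# The first call seeds the accumulator; later chunks are concatenated part-wise.
acc = vertexai_accumulator(None, make_chunk("The capital of France "))
acc = vertexai_accumulator(acc, make_chunk("is Paris."))
assert acc.candidates[0].content.parts[0].text == "The capital of France is Paris."
```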
4 changes: 4 additions & 0 deletions weave/trace/autopatch.py
@@ -21,6 +21,7 @@ def autopatch() -> None:
    from weave.integrations.mistral import mistral_patcher
    from weave.integrations.notdiamond.tracing import notdiamond_patcher
    from weave.integrations.openai.openai_sdk import openai_patcher
    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

    openai_patcher.attempt_patch()
    mistral_patcher.attempt_patch()
@@ -35,6 +36,7 @@ def autopatch() -> None:
    cohere_patcher.attempt_patch()
    google_genai_patcher.attempt_patch()
    notdiamond_patcher.attempt_patch()
    vertexai_patcher.attempt_patch()


def reset_autopatch() -> None:
@@ -53,6 +55,7 @@ def reset_autopatch() -> None:
    from weave.integrations.mistral import mistral_patcher
    from weave.integrations.notdiamond.tracing import notdiamond_patcher
    from weave.integrations.openai.openai_sdk import openai_patcher
    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

    openai_patcher.undo_patch()
    mistral_patcher.undo_patch()
@@ -67,3 +70,4 @@ def reset_autopatch() -> None:
    cohere_patcher.undo_patch()
    google_genai_patcher.undo_patch()
    notdiamond_patcher.undo_patch()
    vertexai_patcher.undo_patch()
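
With the patcher registered here, no extra setup is needed at call sites; presumably a bare `weave.init` is enough to activate it (sketch; the project name is a placeholder):

```python
import weave

# weave.init(...) invokes autopatch(), which patches the Vertex AI SDK
# (among other integrations) so subsequent calls are traced automatically.
weave.init("vertex-ai-test")
```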