
feat(weave): Implement VertexAI integration #2743

Merged · 35 commits · Nov 26, 2024
Changes shown from 9 of the 35 commits.

Commits (35)
042293c  add: vertexai autopatch integration (soumik12345, Oct 21, 2024)
8a228bb  Merge branch 'master' into feat/vertexai (soumik12345, Oct 21, 2024)
42a257b  fix: lint (soumik12345, Oct 21, 2024)
57e414c  add: vertexai_on_finish for handling token count, pricing, and executi… (soumik12345, Oct 21, 2024)
fc651d3  add: tests for vertexai integration (soumik12345, Oct 21, 2024)
46e438f  fix: lint (soumik12345, Oct 21, 2024)
5cd3a90  Merge branch 'master' into feat/vertexai (soumik12345, Oct 21, 2024)
a56f8d8  Merge branch 'master' into feat/vertexai (soumik12345, Oct 28, 2024)
6bb2d76  add: patch for ImageGenerationModel.generate_images (soumik12345, Oct 28, 2024)
e4414cf  Merge branch 'master' into feat/vertexai (soumik12345, Oct 28, 2024)
33d0e7d  Merge branch 'master' into feat/vertexai (soumik12345, Oct 28, 2024)
231ade1  update: tests (soumik12345, Oct 28, 2024)
a40d4e1  update: cassettes (soumik12345, Oct 28, 2024)
3f62601  fix: lint (soumik12345, Oct 28, 2024)
9f8238b  Merge branch 'master' into feat/vertexai (soumik12345, Oct 29, 2024)
21dc1d7  Merge branch 'master' into feat/vertexai (soumik12345, Oct 30, 2024)
fe9d10a  add: cassettes for async cases (soumik12345, Oct 30, 2024)
aad4297  Merge branch 'master' into feat/vertexai (soumik12345, Nov 4, 2024)
fb81f43  Merge branch 'master' into feat/vertexai (soumik12345, Nov 4, 2024)
c6a1274  update: tests (soumik12345, Nov 4, 2024)
7ae4ff3  update: docs (soumik12345, Nov 4, 2024)
65c4cb3  fix: lint (soumik12345, Nov 4, 2024)
5f9c74a  Merge branch 'master' into feat/vertexai (soumik12345, Nov 5, 2024)
3356f64  Merge branch 'master' into feat/vertexai (soumik12345, Nov 6, 2024)
6b56bf8  Merge branch 'master' into feat/vertexai (soumik12345, Nov 7, 2024)
90795f0  Merge branch 'master' into feat/vertexai (soumik12345, Nov 7, 2024)
e137602  add: dictify support (soumik12345, Nov 7, 2024)
a33ab2c  Merge branch 'master' into feat/vertexai (soumik12345, Nov 7, 2024)
e851329  Merge branch 'master' into feat/vertexai (soumik12345, Nov 7, 2024)
00b1657  update: tests (soumik12345, Nov 7, 2024)
22b360f  Merge branch 'master' into feat/vertexai (soumik12345, Nov 7, 2024)
b1763d6  Merge branch 'master' into feat/vertexai (soumik12345, Nov 22, 2024)
f256ec4  add: skips to vertexai tests (soumik12345, Nov 22, 2024)
a0b3a77  Merge branch 'master' into feat/vertexai (soumik12345, Nov 25, 2024)
8a56bde  Merge branch 'master' into feat/vertexai (soumik12345, Nov 25, 2024)
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -240,6 +240,7 @@ jobs:
'mistral1',
'notdiamond',
'openai',
'vertexai',
'pandas-test',
]
fail-fast: false
1 change: 1 addition & 0 deletions noxfile.py
@@ -40,6 +40,7 @@ def lint(session):
"mistral1",
"notdiamond",
"openai",
"vertexai",
"pandas-test",
],
)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ notdiamond = ["notdiamond>=0.3.21", "litellm<=1.49.1"]
openai = ["openai>=1.0.0"]
pandas-test = ["pandas>=2.2.3"]
modal = ["modal", "python-dotenv"]
vertexai = ["vertexai>=1.70.0"]
test = [
"nox",
"pytest>=8.2.0",
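This defines a pip extra, so the integration's dependencies would install with something like pip install weave[vertexai] (the exact command is an assumption based on the package name).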
103 changes: 103 additions & 0 deletions tests/integrations/vertexai/vertexai_test.py
@@ -0,0 +1,103 @@
import pytest

from weave.integrations.integration_utilities import op_name_from_ref


@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="wandb-growth", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")
model.generate_content("Explain how AI works in simple terms")

calls = list(client.calls())
assert len(calls) == 1

call = calls[0]
assert call.started_at < call.ended_at

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "vertexai.GenerativeModel.generate_content"
assert call.output is not None
Collaborator:
can you make a more relevant assert using a mock? This is just checking if anything populated at all.
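A minimal sketch of such a tighter assertion, assuming a recorded or mocked response with known token counts; the summary keys follow what vertexai_on_finish writes in vertexai_sdk.py below, and the numeric values are placeholders:

    usage = call.summary["usage"]["gemini-1.5-flash"]
    assert usage["requests"] == 1
    assert usage["prompt_tokens"] == 7        # placeholder: taken from the cassette/mock
    assert usage["completion_tokens"] == 42   # placeholder: taken from the cassette/mock
    assert usage["total_tokens"] == 49        # placeholder: 7 + 42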

Contributor Author (soumik12345):
I added cassettes for test_content_generation and test_content_generation_stream. However, I'm unable to generate them for test_content_generation_async and test_content_generation_async_stream.

Collaborator:
Do cassettes work here? I thought this used the google grpc stuff under the hood?

In the case where they don't work, you (or gpt!) will need to mock out the request/response manually.
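A minimal sketch of mocking the response manually, assuming vertexai routes generate_content through the v1beta1 PredictionServiceClient (the exact transport path may differ across google-cloud-aiplatform versions), with model being the GenerativeModel from the test above:

    from unittest import mock

    from google.cloud.aiplatform_v1beta1.services.prediction_service.client import (
        PredictionServiceClient,
    )
    from google.cloud.aiplatform_v1beta1.types import prediction_service as gapic_types

    # Canned proto response; the text and token counts are illustrative.
    canned = gapic_types.GenerateContentResponse(
        candidates=[
            {"content": {"role": "model", "parts": [{"text": "AI finds patterns in data."}]}}
        ],
        usage_metadata={
            "prompt_token_count": 7,
            "candidates_token_count": 6,
            "total_token_count": 13,
        },
    )

    with mock.patch.object(
        PredictionServiceClient, "generate_content", return_value=canned
    ):
        response = model.generate_content("Explain how AI works in simple terms")

    assert response.text == "AI finds patterns in data."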

Contributor Author (soumik12345):
@andrewtruong Do you mean that we need to write the cassettes manually for those specific functions? Or is there another way?



@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="wandb-growth", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")
response = model.generate_content(
"Explain how AI works in simple terms", stream=True
)
chunks = [chunk.text for chunk in response]
assert len(chunks) > 1

calls = list(client.calls())
assert len(calls) == 1

call = calls[0]
assert call.started_at < call.ended_at

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "vertexai.GenerativeModel.generate_content"
assert call.output is not None
soumik12345 marked this conversation as resolved.


@pytest.mark.retry(max_attempts=5)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="wandb-growth", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")
_ = await model.generate_content_async("Explain how AI works in simple terms")

calls = list(client.calls())
assert len(calls) == 1

call = calls[0]
assert call.started_at < call.ended_at

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "vertexai.GenerativeModel.generate_content_async"
assert call.output is not None
soumik12345 marked this conversation as resolved.


@pytest.mark.retry(max_attempts=5)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_stream(client):
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="wandb-growth", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")

async def get_response():
chunks = []
async for chunk in await model.generate_content_async(
"Explain how AI works in simple terms", stream=True
):
if chunk.text:
chunks.append(chunk.text)
return chunks

_ = await get_response()

calls = list(client.calls())
assert len(calls) == 1

call = calls[0]
assert call.started_at < call.ended_at

trace_name = op_name_from_ref(call.op_name)
assert trace_name == "vertexai.GenerativeModel.generate_content_async"
assert call.output is not None
soumik12345 marked this conversation as resolved.
Empty file.
128 changes: 128 additions & 0 deletions weave/integrations/vertexai/vertexai_sdk.py
@@ -0,0 +1,128 @@
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call

if TYPE_CHECKING:
from vertexai.generative_models import GenerationResponse


def vertexai_accumulator(
acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
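    # Streaming accumulator: each streamed chunk is merged into the running
    # response by concatenating part texts per candidate; usage token counts
    # are summed at the end.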
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import GenerationResponse

if acc is None:
return value

candidates = []
for i, value_candidate in enumerate(value.candidates):
accumulated_texts = []
for j, value_part in enumerate(value_candidate.content.parts):
accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
accumulated_texts.append(accumulated_text)
parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
content = gapic_content_types.Content(
role=value_candidate.content.role, parts=parts
)
candidate = gapic_content_types.Candidate(content=content)
candidates.append(candidate)
accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
candidates=candidates
)
acc = GenerationResponse._from_gapic(accumulated_response)

acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
acc.usage_metadata.candidates_token_count += (
value.usage_metadata.candidates_token_count
)
acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
return acc


def vertexai_on_finish(
call: Call, output: Any, exception: Optional[BaseException]
) -> None:
original_model_name = call.inputs["self"]._model_name
model_name = original_model_name.split("/")[-1]
usage = {model_name: {"requests": 1}}
summary_update = {"usage": usage}
if output:
usage[model_name].update(
{
"prompt_tokens": output.usage_metadata.prompt_token_count,
"completion_tokens": output.usage_metadata.candidates_token_count,
"total_tokens": output.usage_metadata.total_token_count,
}
)
if call.summary is not None:
call.summary.update(summary_update)


def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
op._set_on_finish_handler(vertexai_on_finish)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: vertexai_accumulator,
should_accumulate=lambda inputs: isinstance(inputs, dict)
and bool(inputs.get("stream")),
)

return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return await fn(*args, **kwargs)

return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op.name = name # type: ignore
op._set_on_finish_handler(vertexai_on_finish)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: vertexai_accumulator,
should_accumulate=lambda inputs: isinstance(inputs, dict)
and bool(inputs.get("stream")),
)

return wrapper


vertexai_patcher = MultiPatcher(
[
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content",
vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.generative_models"),
"GenerativeModel.generate_content_async",
vertexai_wrapper_async(
name="vertexai.GenerativeModel.generate_content_async"
),
),
SymbolPatcher(
lambda: importlib.import_module("vertexai.preview.vision_models"),
"ImageGenerationModel.generate_images",
vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
),
]
)
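
Since autopatch() below applies vertexai_patcher at weave.init() time, a minimal usage sketch (project names are illustrative):

    import weave
    import vertexai
    from vertexai.generative_models import GenerativeModel

    weave.init("vertexai-demo")  # illustrative weave project name
    vertexai.init(project="my-gcp-project", location="us-central1")  # illustrative GCP project

    model = GenerativeModel("gemini-1.5-flash")
    # Traced as the op "vertexai.GenerativeModel.generate_content"; token usage
    # is recorded by vertexai_on_finish.
    response = model.generate_content("Explain how AI works in simple terms")
    print(response.text)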
4 changes: 4 additions & 0 deletions weave/trace/autopatch.py
@@ -21,6 +21,7 @@ def autopatch() -> None:
from weave.integrations.mistral import mistral_patcher
from weave.integrations.notdiamond.tracing import notdiamond_patcher
from weave.integrations.openai.openai_sdk import openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

openai_patcher.attempt_patch()
mistral_patcher.attempt_patch()
@@ -35,6 +36,7 @@ def autopatch() -> None:
cohere_patcher.attempt_patch()
google_genai_patcher.attempt_patch()
notdiamond_patcher.attempt_patch()
vertexai_patcher.attempt_patch()


def reset_autopatch() -> None:
@@ -53,6 +55,7 @@ def reset_autopatch() -> None:
from weave.integrations.mistral import mistral_patcher
from weave.integrations.notdiamond.tracing import notdiamond_patcher
from weave.integrations.openai.openai_sdk import openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher

openai_patcher.undo_patch()
mistral_patcher.undo_patch()
@@ -67,3 +70,4 @@ def reset_autopatch() -> None:
cohere_patcher.undo_patch()
google_genai_patcher.undo_patch()
notdiamond_patcher.undo_patch()
vertexai_patcher.undo_patch()