diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 0605b0534df..295b6ef8922 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -240,6 +240,7 @@ jobs:
             'mistral1',
             'notdiamond',
             'openai',
+            'vertexai',
             'scorers_tests',
             'pandas-test',
           ]
diff --git a/docs/docs/guides/integrations/google-gemini.md b/docs/docs/guides/integrations/google-gemini.md
index 351fb1247e7..6afc2790b3d 100644
--- a/docs/docs/guides/integrations/google-gemini.md
+++ b/docs/docs/guides/integrations/google-gemini.md
@@ -16,13 +16,28 @@ import os
 import google.generativeai as genai
 import weave
 
-weave.init(project_name="google_ai_studio-test")
+weave.init(project_name="google-ai-studio-test")
 genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
 model = genai.GenerativeModel("gemini-1.5-flash")
 response = model.generate_content("Write a story about an AI and magic")
 ```
 
+Weave will also automatically capture traces for the [Vertex AI API](https://cloud.google.com/vertex-ai/docs). To start tracking, call `weave.init(project_name="<project_name>")` and use the library as usual.
+
+```python
+import vertexai
+import weave
+from vertexai.generative_models import GenerativeModel
+
+weave.init(project_name="vertex-ai-test")
+vertexai.init(project="<project>", location="<location>")
+model = GenerativeModel("gemini-1.5-flash-002")
+response = model.generate_content(
+    "What's a good name for a flower shop specialising in selling dried flower bouquets?"
+)
+```
+
 ## Track your own ops
 
 Wrapping a function with `@weave.op` starts capturing inputs, outputs and app logic so you can debug how data flows through your app. You can deeply nest ops and build a tree of functions that you want to track. This also starts automatically versioning code as you experiment to capture ad-hoc details that haven't been committed to git.
@@ -97,11 +112,3 @@ Given a weave reference to any `weave.Model` object, you can spin up a fastapi s
 ```shell
 weave serve weave:///your_entity/project-name/YourModel:<hash>
 ```
-
-## Vertex API
-
-Full Weave support for the `Vertex AI SDK` python package is currently in development, however there is a way you can integrate Weave with the Vertex API.
-
-Vertex API supports OpenAI SDK compatibility ([docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library)), and if this is a way you build your application, Weave will automatically track your LLM calls via our [OpenAI](/guides/integrations/openai) SDK integration.
-
-\* Please note that some features may not fully work as Vertex API doesn't implement the full OpenAI SDK capabilities.
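The docs hunk above only shows a blocking call, but the accumulator added in `weave/integrations/vertexai/vertexai_sdk.py` (further down in this diff) also folds streamed chunks into a single traced response. A minimal sketch of that usage, assuming the same `<project>`/`<location>` placeholders as the docs example:

```python
import vertexai
import weave
from vertexai.generative_models import GenerativeModel

weave.init(project_name="vertex-ai-test")
vertexai.init(project="<project>", location="<location>")

model = GenerativeModel("gemini-1.5-flash-002")
# With stream=True, the patched op accumulates the chunks, so the trace
# records one call with the concatenated text and summed token counts.
for chunk in model.generate_content("Write a short tagline.", stream=True):
    print(chunk.text, end="")
```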
diff --git a/noxfile.py b/noxfile.py
index ef12e4d00ed..dff1305f21f 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -47,6 +47,7 @@ def lint(session):
         "mistral1",
         "notdiamond",
         "openai",
+        "vertexai",
         "scorers_tests",
         "pandas-test",
     ],
diff --git a/pyproject.toml b/pyproject.toml
index d292166dd59..617d082e988 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,7 @@ notdiamond = ["notdiamond>=0.3.21", "litellm<=1.49.1"]
 openai = ["openai>=1.0.0"]
 pandas-test = ["pandas>=2.2.3"]
 modal = ["modal", "python-dotenv"]
+vertexai = ["vertexai>=1.70.0"]
 test = [
     "nox",
     "pytest>=8.2.0",
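The new `vertexai` extra only pins `vertexai>=1.70.0`. A quick sanity check that a local environment satisfies the floor (assuming the extra was installed, e.g. `pip install -e ".[vertexai]"` from a checkout):

```python
from importlib.metadata import version

# The extra added above requires vertexai>=1.70.0; print what is installed.
print(version("vertexai"))
```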
diff --git a/tests/integrations/vertexai/vertexai_test.py b/tests/integrations/vertexai/vertexai_test.py
new file mode 100644
index 00000000000..4af39e62e32
--- /dev/null
+++ b/tests/integrations/vertexai/vertexai_test.py
@@ -0,0 +1,125 @@
+import pytest
+
+from weave.integrations.integration_utilities import op_name_from_ref
+
+
+@pytest.mark.skip(
+    reason="This test depends on a non-deterministic external service provider"
+)
+@pytest.mark.flaky(reruns=5, reruns_delay=2)
+@pytest.mark.skip_clickhouse_client
+def test_content_generation(client):
+    import vertexai
+    from vertexai.generative_models import GenerativeModel
+
+    vertexai.init(project="wandb-growth", location="us-central1")
+    model = GenerativeModel("gemini-1.5-flash")
+    model.generate_content("What is the capital of France?")
+
+    calls = list(client.calls())
+    assert len(calls) == 1
+
+    call = calls[0]
+    assert call.started_at < call.ended_at
+
+    trace_name = op_name_from_ref(call.op_name)
+    assert trace_name == "vertexai.GenerativeModel.generate_content"
+    output = call.output
+    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
+    assert output["candidates"][0]["content"]["role"] == "model"
+    assert output["candidates"][0]["finish_reason"] == "STOP"
+    assert "gemini-1.5-flash" in output["model_version"]
+
+
+@pytest.mark.skip(
+    reason="This test depends on a non-deterministic external service provider"
+)
+@pytest.mark.flaky(reruns=5, reruns_delay=2)
+@pytest.mark.skip_clickhouse_client
+def test_content_generation_stream(client):
+    import vertexai
+    from vertexai.generative_models import GenerativeModel
+
+    vertexai.init(project="wandb-growth", location="us-central1")
+    model = GenerativeModel("gemini-1.5-flash")
+    response = model.generate_content("What is the capital of France?", stream=True)
+    chunks = [chunk.text for chunk in response]
+    assert len(chunks) > 1
+
+    calls = list(client.calls())
+    assert len(calls) == 1
+
+    call = calls[0]
+    assert call.started_at < call.ended_at
+
+    trace_name = op_name_from_ref(call.op_name)
+    assert trace_name == "vertexai.GenerativeModel.generate_content"
+    output = call.output
+    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
+    assert output["candidates"][0]["content"]["role"] == "model"
+
+
+@pytest.mark.skip(
+    reason="This test depends on a non-deterministic external service provider"
+)
+@pytest.mark.flaky(reruns=5, reruns_delay=2)
+@pytest.mark.asyncio
+@pytest.mark.skip_clickhouse_client
+async def test_content_generation_async(client):
+    import vertexai
+    from vertexai.generative_models import GenerativeModel
+
+    vertexai.init(project="wandb-growth", location="us-central1")
+    model = GenerativeModel("gemini-1.5-flash")
+    await model.generate_content_async("What is the capital of France?")
+
+    calls = list(client.calls())
+    assert len(calls) == 1
+
+    call = calls[0]
+    assert call.started_at < call.ended_at
+
+    trace_name = op_name_from_ref(call.op_name)
+    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
+    output = call.output
+    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
+    assert output["candidates"][0]["content"]["role"] == "model"
+    assert output["candidates"][0]["finish_reason"] == "STOP"
+    assert "gemini-1.5-flash" in output["model_version"]
+
+
+@pytest.mark.skip(
+    reason="This test depends on a non-deterministic external service provider"
+)
+@pytest.mark.flaky(reruns=5, reruns_delay=2)
+@pytest.mark.asyncio
+@pytest.mark.skip_clickhouse_client
+async def test_content_generation_async_stream(client):
+    import vertexai
+    from vertexai.generative_models import GenerativeModel
+
+    vertexai.init(project="wandb-growth", location="us-central1")
+    model = GenerativeModel("gemini-1.5-flash")
+
+    async def get_response():
+        chunks = []
+        async for chunk in await model.generate_content_async(
+            "What is the capital of France?", stream=True
+        ):
+            if chunk.text:
+                chunks.append(chunk.text)
+        return chunks
+
+    await get_response()
+
+    calls = list(client.calls())
+    assert len(calls) == 1
+
+    call = calls[0]
+    assert call.started_at < call.ended_at
+
+    trace_name = op_name_from_ref(call.op_name)
+    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
+    output = call.output
+    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
+    assert output["candidates"][0]["content"]["role"] == "model"
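The tests above run against the live service (hence the `skip` and `flaky` markers), but the same flow works outside the harness. A rough sketch, assuming GCP credentials for a real `<project>` and that `weave.init` has already applied the patches (see `autopatch.py` at the end of this diff):

```python
import vertexai
import weave
from vertexai.generative_models import GenerativeModel

client = weave.init(project_name="vertex-ai-test")
vertexai.init(project="<project>", location="us-central1")

model = GenerativeModel("gemini-1.5-flash")
model.generate_content("What is the capital of France?")

# The patched method shows up as a traced call on the client,
# exactly as the tests assert.
call = list(client.calls())[0]
print(call.op_name)  # ends in "vertexai.GenerativeModel.generate_content"
```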
diff --git a/weave/integrations/vertexai/__init__.py b/weave/integrations/vertexai/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py
new file mode 100644
index 00000000000..348c9760ac4
--- /dev/null
+++ b/weave/integrations/vertexai/vertexai_sdk.py
@@ -0,0 +1,128 @@
+import importlib
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import weave
+from weave.trace.op_extensions.accumulator import add_accumulator
+from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.weave_client import Call
+
+if TYPE_CHECKING:
+    from vertexai.generative_models import GenerationResponse
+
+
+def vertexai_accumulator(
+    acc: Optional["GenerationResponse"], value: "GenerationResponse"
+) -> "GenerationResponse":
+    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
+    from google.cloud.aiplatform_v1beta1.types import (
+        prediction_service as gapic_prediction_service_types,
+    )
+    from vertexai.generative_models import GenerationResponse
+
+    if acc is None:
+        return value
+
+    candidates = []
+    for i, value_candidate in enumerate(value.candidates):
+        accumulated_texts = []
+        for j, value_part in enumerate(value_candidate.content.parts):
+            accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
+            accumulated_texts.append(accumulated_text)
+        parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
+        content = gapic_content_types.Content(
+            role=value_candidate.content.role, parts=parts
+        )
+        candidate = gapic_content_types.Candidate(content=content)
+        candidates.append(candidate)
+    accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
+        candidates=candidates
+    )
+    acc = GenerationResponse._from_gapic(accumulated_response)
+
+    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
+    acc.usage_metadata.candidates_token_count += (
+        value.usage_metadata.candidates_token_count
+    )
+    acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
+    return acc
+
+
+def vertexai_on_finish(
+    call: Call, output: Any, exception: Optional[BaseException]
+) -> None:
+    original_model_name = call.inputs["self"]._model_name
+    model_name = original_model_name.split("/")[-1]
+    usage = {model_name: {"requests": 1}}
+    summary_update = {"usage": usage}
+    if output:
+        usage[model_name].update(
+            {
+                "prompt_tokens": output.usage_metadata.prompt_token_count,
+                "completion_tokens": output.usage_metadata.candidates_token_count,
+                "total_tokens": output.usage_metadata.total_token_count,
+            }
+        )
+    if call.summary is not None:
+        call.summary.update(summary_update)
+
+
+def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
+    def wrapper(fn: Callable) -> Callable:
+        op = weave.op()(fn)
+        op.name = name  # type: ignore
+        op._set_on_finish_handler(vertexai_on_finish)
+        return add_accumulator(
+            op,  # type: ignore
+            make_accumulator=lambda inputs: vertexai_accumulator,
+            should_accumulate=lambda inputs: isinstance(inputs, dict)
+            and bool(inputs.get("stream")),
+        )
+
+    return wrapper
+
+
+def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
+    def wrapper(fn: Callable) -> Callable:
+        def _fn_wrapper(fn: Callable) -> Callable:
+            @wraps(fn)
+            async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
+                return await fn(*args, **kwargs)
+
+            return _async_wrapper
+
+        # We need to do this so we can check if `stream` is used
+        op = weave.op()(_fn_wrapper(fn))
+        op.name = name  # type: ignore
+        op._set_on_finish_handler(vertexai_on_finish)
+        return add_accumulator(
+            op,  # type: ignore
+            make_accumulator=lambda inputs: vertexai_accumulator,
+            should_accumulate=lambda inputs: isinstance(inputs, dict)
+            and bool(inputs.get("stream")),
+        )
+
+    return wrapper
+
+
+vertexai_patcher = MultiPatcher(
+    [
+        SymbolPatcher(
+            lambda: importlib.import_module("vertexai.generative_models"),
+            "GenerativeModel.generate_content",
+            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
+        ),
+        SymbolPatcher(
+            lambda: importlib.import_module("vertexai.generative_models"),
+            "GenerativeModel.generate_content_async",
+            vertexai_wrapper_async(
+                name="vertexai.GenerativeModel.generate_content_async"
+            ),
+        ),
+        SymbolPatcher(
+            lambda: importlib.import_module("vertexai.preview.vision_models"),
+            "ImageGenerationModel.generate_images",
+            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
+        ),
+    ]
+)
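For reference, `vertexai_on_finish` keys usage by the short model name (the trailing segment of `_model_name`) and copies the token counts from `usage_metadata`. The summary it writes has this shape (token numbers here are illustrative only):

```python
# Illustrative shape of call.summary["usage"] after a non-streaming call:
usage = {
    "gemini-1.5-flash": {
        "requests": 1,
        "prompt_tokens": 7,       # output.usage_metadata.prompt_token_count
        "completion_tokens": 12,  # output.usage_metadata.candidates_token_count
        "total_tokens": 19,       # output.usage_metadata.total_token_count
    }
}
```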
diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py
index de37d951032..3a5dca14556 100644
--- a/weave/trace/autopatch.py
+++ b/weave/trace/autopatch.py
@@ -21,6 +21,7 @@ def autopatch() -> None:
     from weave.integrations.mistral import mistral_patcher
     from weave.integrations.notdiamond.tracing import notdiamond_patcher
     from weave.integrations.openai.openai_sdk import openai_patcher
+    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
 
     openai_patcher.attempt_patch()
     mistral_patcher.attempt_patch()
@@ -35,6 +36,7 @@ def autopatch() -> None:
     cohere_patcher.attempt_patch()
     google_genai_patcher.attempt_patch()
     notdiamond_patcher.attempt_patch()
+    vertexai_patcher.attempt_patch()
 
 
 def reset_autopatch() -> None:
@@ -53,6 +55,7 @@ def reset_autopatch() -> None:
     from weave.integrations.mistral import mistral_patcher
     from weave.integrations.notdiamond.tracing import notdiamond_patcher
     from weave.integrations.openai.openai_sdk import openai_patcher
+    from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
 
     openai_patcher.undo_patch()
     mistral_patcher.undo_patch()
@@ -67,3 +70,4 @@ def reset_autopatch() -> None:
     cohere_patcher.undo_patch()
     google_genai_patcher.undo_patch()
     notdiamond_patcher.undo_patch()
+    vertexai_patcher.undo_patch()
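With the patcher registered here, tracing is symmetric: `autopatch()` (typically invoked during `weave.init`) attempts the patch and `reset_autopatch()` undoes it. A minimal sketch of toggling it by hand:

```python
from weave.trace.autopatch import autopatch, reset_autopatch

autopatch()        # vertexai_patcher.attempt_patch() runs with the other patchers
# ... vertexai.generative_models calls are now traced ...
reset_autopatch()  # vertexai_patcher.undo_patch() restores the original methods
```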