feat(weave): Implement VertexAI integration (#2743)
* add: vertexai autopatch integration
* fix: lint
* add: vertexai_on_finish for handling token count, pricing, and execution time
* add: tests for vertexai integration
* fix: lint
* add: patch for ImageGenerationModel.generate_images
* update: tests
* update: cassettes
* fix: lint
* add: cassettes for async cases
* update: tests
* update: docs
* fix: lint
* add: dictify support
* update: tests
* add: skips to vertexai tests
1 parent 33a90eb · commit b89825e
Showing 8 changed files with 276 additions and 9 deletions.
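The integration is autopatched, so once Weave is initialized, Vertex AI calls are traced without further code changes. A minimal usage sketch, adapted from the tests added in this commit; the project names are placeholders, and the exact tracing behavior depends on Weave's autopatch configuration:

import weave
import vertexai
from vertexai.generative_models import GenerativeModel

# Initializing Weave is assumed to activate the autopatch integrations,
# including the VertexAI patcher added in this commit.
weave.init("my-weave-project")

# Standard Vertex AI usage: the patched generate_content call is traced,
# and token counts are recorded by vertexai_on_finish.
vertexai.init(project="my-gcp-project", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")
response = model.generate_content("What is the capital of France?")
print(response.candidates[0].content.parts[0].text)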
@@ -240,6 +240,7 @@ jobs:
   'mistral1',
   'notdiamond',
   'openai',
+  'vertexai',
   'scorers_tests',
   'pandas-test',
 ]
@@ -0,0 +1,125 @@
import pytest

from weave.integrations.integration_utilities import op_name_from_ref


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    model.generate_content("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    response = model.generate_content("What is the capital of France?", stream=True)
    chunks = [chunk.text for chunk in response]
    assert len(chunks) > 1

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    await model.generate_content_async("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
    assert output["candidates"][0]["finish_reason"] == "STOP"
    assert "gemini-1.5-flash" in output["model_version"]


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")

    async def get_response():
        chunks = []
        async for chunk in await model.generate_content_async(
            "What is the capital of France?", stream=True
        ):
            if chunk.text:
                chunks.append(chunk.text)
        return chunks

    await get_response()

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    output = call.output
    assert "paris" in output["candidates"][0]["content"]["parts"][0]["text"].lower()
    assert output["candidates"][0]["content"]["role"] == "model"
Empty file.
@@ -0,0 +1,128 @@
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call

if TYPE_CHECKING:
    from vertexai.generative_models import GenerationResponse


def vertexai_accumulator(
    acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
    from google.cloud.aiplatform_v1beta1.types import (
        prediction_service as gapic_prediction_service_types,
    )
    from vertexai.generative_models import GenerationResponse

    if acc is None:
        return value

    candidates = []
    for i, value_candidate in enumerate(value.candidates):
        accumulated_texts = []
        for j, value_part in enumerate(value_candidate.content.parts):
            accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
            accumulated_texts.append(accumulated_text)
        parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
        content = gapic_content_types.Content(
            role=value_candidate.content.role, parts=parts
        )
        candidate = gapic_content_types.Candidate(content=content)
        candidates.append(candidate)
    accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
        candidates=candidates
    )
    acc = GenerationResponse._from_gapic(accumulated_response)

    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
    acc.usage_metadata.candidates_token_count += (
        value.usage_metadata.candidates_token_count
    )
    acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
    return acc


def vertexai_on_finish(
    call: Call, output: Any, exception: Optional[BaseException]
) -> None:
    original_model_name = call.inputs["self"]._model_name
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        usage[model_name].update(
            {
                "prompt_tokens": output.usage_metadata.prompt_token_count,
                "completion_tokens": output.usage_metadata.candidates_token_count,
                "total_tokens": output.usage_metadata.total_token_count,
            }
        )
    if call.summary is not None:
        call.summary.update(summary_update)


def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        op = weave.op()(fn)
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        def _fn_wrapper(fn: Callable) -> Callable:
            @wraps(fn)
            async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await fn(*args, **kwargs)

            return _async_wrapper

        # We need to do this so we can check if `stream` is used
        op = weave.op()(_fn_wrapper(fn))
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


vertexai_patcher = MultiPatcher(
    [
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content",
            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content_async",
            vertexai_wrapper_async(
                name="vertexai.GenerativeModel.generate_content_async"
            ),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.preview.vision_models"),
            "ImageGenerationModel.generate_images",
            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
        ),
    ]
)
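As a reading aid, here is a minimal, self-contained sketch of the usage bookkeeping performed by vertexai_on_finish above. FakeUsage, FakeOutput, and summarize are hypothetical stand-ins for illustration only; they are not Weave or Vertex AI types.

from dataclasses import dataclass, field


@dataclass
class FakeUsage:
    # Stand-in for the usage_metadata attached to a Vertex AI response.
    prompt_token_count: int = 7
    candidates_token_count: int = 5
    total_token_count: int = 12


@dataclass
class FakeOutput:
    usage_metadata: FakeUsage = field(default_factory=FakeUsage)


def summarize(original_model_name: str, output: FakeOutput) -> dict:
    # Mirrors the dictionary that vertexai_on_finish merges into call.summary.
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    if output:
        usage[model_name].update(
            {
                "prompt_tokens": output.usage_metadata.prompt_token_count,
                "completion_tokens": output.usage_metadata.candidates_token_count,
                "total_tokens": output.usage_metadata.total_token_count,
            }
        )
    return {"usage": usage}


print(summarize("publishers/google/models/gemini-1.5-flash", FakeOutput()))
# {'usage': {'gemini-1.5-flash': {'requests': 1, 'prompt_tokens': 7,
#            'completion_tokens': 5, 'total_tokens': 12}}}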