feat(weave): Implement VertexAI integration #2743
Merged
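For context, a minimal sketch of what this integration enables once merged, assuming Weave autopatches the vertexai SDK on weave.init (as it does for its other SDK integrations); the project names below are placeholders:

import vertexai
from vertexai.generative_models import GenerativeModel

import weave

# Assumption: weave.init triggers the vertexai autopatch added in this PR,
# so the call below is traced as "vertexai.GenerativeModel.generate_content".
weave.init("my-weave-project")  # placeholder project name

vertexai.init(project="my-gcp-project", location="us-central1")  # placeholders
model = GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Explain how AI works in simple terms")
print(response.text)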
Commits (35, all by soumik12345):

042293c  add: vertexai autopatch integration
8a228bb  Merge branch 'master' into feat/vertexai
42a257b  fix: lint
57e414c  add: vertexai_on_finish for handing token count, pricing, and executi…
fc651d3  add: tests for vertexai integration
46e438f  fix: lint
5cd3a90  Merge branch 'master' into feat/vertexai
a56f8d8  Merge branch 'master' into feat/vertexai
6bb2d76  add: patch for ImageGenerationModel.generate_images
e4414cf  Merge branch 'master' into feat/vertexai
33d0e7d  Merge branch 'master' into feat/vertexai
231ade1  update: tests
a40d4e1  update: cassettes
3f62601  fix: lint
9f8238b  Merge branch 'master' into feat/vertexai
21dc1d7  Merge branch 'master' into feat/vertexai
fe9d10a  add: cassettes for async cases
aad4297  Merge branch 'master' into feat/vertexai
fb81f43  Merge branch 'master' into feat/vertexai
c6a1274  update: tests
7ae4ff3  update: docs
65c4cb3  fix: lint
5f9c74a  Merge branch 'master' into feat/vertexai
3356f64  Merge branch 'master' into feat/vertexai
6b56bf8  Merge branch 'master' into feat/vertexai
90795f0  Merge branch 'master' into feat/vertexai
e137602  add: dictify support
a33ab2c  Merge branch 'master' into feat/vertexai
e851329  Merge branch 'master' into feat/vertexai
00b1657  update: tests
22b360f  Merge branch 'master' into feat/vertexai
b1763d6  Merge branch 'master' into feat/vertexai
f256ec4  add: skips to vertexai tests
a0b3a77  Merge branch 'master' into feat/vertexai
8a56bde  Merge branch 'master' into feat/vertexai
@@ -240,6 +240,7 @@ jobs:
     'mistral1',
     'notdiamond',
     'openai',
+    'vertexai',
     'pandas-test',
 ]
 fail-fast: false
@@ -40,6 +40,7 @@ def lint(session):
     "mistral1",
     "notdiamond",
     "openai",
+    "vertexai",
     "pandas-test",
 ],
 )
@@ -0,0 +1,103 @@
import pytest

from weave.integrations.integration_utilities import op_name_from_ref


@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    model.generate_content("Explain how AI works in simple terms")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    assert call.output is not None


@pytest.mark.retry(max_attempts=5)
@pytest.mark.skip_clickhouse_client
def test_content_generation_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(
        "Explain how AI works in simple terms", stream=True
    )
    chunks = [chunk.text for chunk in response]
    assert len(chunks) > 1

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content"
    assert call.output is not None


@pytest.mark.retry(max_attempts=5)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")
    _ = await model.generate_content_async("Explain how AI works in simple terms")

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    assert call.output is not None


@pytest.mark.retry(max_attempts=5)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_stream(client):
    import vertexai
    from vertexai.generative_models import GenerativeModel

    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")

    async def get_response():
        chunks = []
        async for chunk in await model.generate_content_async(
            "Explain how AI works in simple terms", stream=True
        ):
            if chunk.text:
                chunks.append(chunk.text)
        return chunks

    _ = await get_response()

    calls = list(client.calls())
    assert len(calls) == 1

    call = calls[0]
    assert call.started_at < call.ended_at

    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "vertexai.GenerativeModel.generate_content_async"
    assert call.output is not None
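The assertions above only check that an output populated at all. Since vertexai_on_finish in the integration module later in this diff writes token counts into the call summary, a tighter check is possible. A sketch, assuming call.summary is dict-like here and the model resource name ends in "gemini-1.5-flash":

# Hedged sketch: tighter assertions against the usage summary that
# vertexai_on_finish attaches to each call.
usage = call.summary["usage"]["gemini-1.5-flash"]
assert usage["requests"] == 1
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]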
Empty file.
@@ -0,0 +1,128 @@
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call

if TYPE_CHECKING:
    from vertexai.generative_models import GenerationResponse


def vertexai_accumulator(
    acc: Optional["GenerationResponse"], value: "GenerationResponse"
) -> "GenerationResponse":
    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
    from google.cloud.aiplatform_v1beta1.types import (
        prediction_service as gapic_prediction_service_types,
    )
    from vertexai.generative_models import GenerationResponse

    if acc is None:
        return value

    candidates = []
    for i, value_candidate in enumerate(value.candidates):
        accumulated_texts = []
        for j, value_part in enumerate(value_candidate.content.parts):
            accumulated_text = acc.candidates[i].content.parts[j].text + value_part.text
            accumulated_texts.append(accumulated_text)
        parts = [gapic_content_types.Part(text=text) for text in accumulated_texts]
        content = gapic_content_types.Content(
            role=value_candidate.content.role, parts=parts
        )
        candidate = gapic_content_types.Candidate(content=content)
        candidates.append(candidate)
    accumulated_response = gapic_prediction_service_types.GenerateContentResponse(
        candidates=candidates
    )
    acc = GenerationResponse._from_gapic(accumulated_response)

    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
    acc.usage_metadata.candidates_token_count += (
        value.usage_metadata.candidates_token_count
    )
    acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count
    return acc


def vertexai_on_finish(
    call: Call, output: Any, exception: Optional[BaseException]
) -> None:
    original_model_name = call.inputs["self"]._model_name
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        usage[model_name].update(
            {
                "prompt_tokens": output.usage_metadata.prompt_token_count,
                "completion_tokens": output.usage_metadata.candidates_token_count,
                "total_tokens": output.usage_metadata.total_token_count,
            }
        )
    if call.summary is not None:
        call.summary.update(summary_update)


def vertexai_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        op = weave.op()(fn)
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


def vertexai_wrapper_async(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        def _fn_wrapper(fn: Callable) -> Callable:
            @wraps(fn)
            async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await fn(*args, **kwargs)

            return _async_wrapper

        # We need to do this so we can check if `stream` is used
        op = weave.op()(_fn_wrapper(fn))
        op.name = name  # type: ignore
        op._set_on_finish_handler(vertexai_on_finish)
        return add_accumulator(
            op,  # type: ignore
            make_accumulator=lambda inputs: vertexai_accumulator,
            should_accumulate=lambda inputs: isinstance(inputs, dict)
            and bool(inputs.get("stream")),
        )

    return wrapper


vertexai_patcher = MultiPatcher(
    [
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content",
            vertexai_wrapper_sync(name="vertexai.GenerativeModel.generate_content"),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.generative_models"),
            "GenerativeModel.generate_content_async",
            vertexai_wrapper_async(
                name="vertexai.GenerativeModel.generate_content_async"
            ),
        ),
        SymbolPatcher(
            lambda: importlib.import_module("vertexai.preview.vision_models"),
            "ImageGenerationModel.generate_images",
            vertexai_wrapper_sync(name="vertexai.ImageGenerationModel.generate_images"),
        ),
    ]
)
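To make vertexai_accumulator concrete, a small sketch that merges two fake streamed chunks, with vertexai_accumulator as defined above. The gapic constructors mirror the ones the accumulator itself uses; the UsageMetadata nested type is an assumption about the gapic schema. Note that the merge rebuilds the response from candidates only, so each merge keeps only the incoming chunk's usage counts:

from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
    prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import GenerationResponse


def make_chunk(text: str, prompt: int = 5, completion: int = 3) -> GenerationResponse:
    # Build a one-candidate gapic response and wrap it, mirroring what the
    # SDK yields for each streamed chunk.
    gapic = gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                content=gapic_content_types.Content(
                    role="model", parts=[gapic_content_types.Part(text=text)]
                )
            )
        ],
        # Assumption: UsageMetadata is the nested gapic type for usage counts.
        usage_metadata=gapic_prediction_service_types.GenerateContentResponse.UsageMetadata(
            prompt_token_count=prompt,
            candidates_token_count=completion,
            total_token_count=prompt + completion,
        ),
    )
    return GenerationResponse._from_gapic(gapic)


acc = vertexai_accumulator(None, make_chunk("AI finds "))
acc = vertexai_accumulator(acc, make_chunk("patterns in data."))
assert acc.candidates[0].content.parts[0].text == "AI finds patterns in data."
# Usage is not summed across chunks: the rebuilt response starts from zero
# and then adds only the incoming chunk's counts.
assert acc.usage_metadata.total_token_count == 8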
Review comments:
can you make a more relevant assert using a mock? This is just checking if anything populated at all.

I added cassettes for test_content_generation and test_content_generation_stream. However, I'm unable to generate them for test_content_generation_async and test_content_generation_async_stream.

Do cassettes work here? I thought this used the google grpc stuff under the hood?
In the case where they don't work, you (or gpt!) will need to mock out the request/response manually
@andrewtruong Do you mean that we need to write the cassettes manually for those specific functions? Or is there another way?
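If cassettes indeed can't capture the gRPC transport, a possible fallback (a sketch, untested against this PR) is to stub the SDK's async prediction client so the weave patch on generate_content_async still runs end-to-end. The routing through PredictionServiceAsyncClient.generate_content and the fake response construction are assumptions about google-cloud-aiplatform internals and may differ across SDK versions:

from unittest import mock

import pytest
import vertexai
from google.cloud.aiplatform_v1beta1.services.prediction_service.async_client import (
    PredictionServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
    prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import GenerativeModel

# A deterministic fake response, built with the same gapic types the
# accumulator in this PR uses.
FAKE_RESPONSE = gapic_prediction_service_types.GenerateContentResponse(
    candidates=[
        gapic_content_types.Candidate(
            content=gapic_content_types.Content(
                role="model",
                parts=[gapic_content_types.Part(text="AI learns patterns from data.")],
            )
        )
    ],
    # Assumption: UsageMetadata is the nested gapic type for usage counts.
    usage_metadata=gapic_prediction_service_types.GenerateContentResponse.UsageMetadata(
        prompt_token_count=7, candidates_token_count=6, total_token_count=13
    ),
)


@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_content_generation_async_mocked(client):
    vertexai.init(project="wandb-growth", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash")

    # Assumption: non-streaming async calls go through
    # PredictionServiceAsyncClient.generate_content. Constructing the real
    # client may still require application-default credentials.
    with mock.patch.object(
        PredictionServiceAsyncClient,
        "generate_content",
        new=mock.AsyncMock(return_value=FAKE_RESPONSE),
    ):
        response = await model.generate_content_async("Explain how AI works")

    # With a deterministic fake, the asserts can be exact instead of
    # "output is not None".
    assert response.text == "AI learns patterns from data."
    call = list(client.calls())[0]
    assert call.summary["usage"]["gemini-1.5-flash"]["total_tokens"] == 13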