Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): add two providers to backend #3242

Merged
merged 4 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,8 +1447,8 @@ def completions_create(
if not secret_name:
raise InvalidRequest(f"No secret name found for model {model_name}")
api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
isBedrock = model_info.get("litellm_provider") == "bedrock"
if not api_key and not isBedrock:
provider = model_info.get("litellm_provider")
if not api_key and provider != "bedrock" and provider != "bedrock_converse":
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for model {model_name}",
api_key_name=secret_name,
Expand All @@ -1458,7 +1458,7 @@ def completions_create(
res = lite_llm_completion(
api_key,
req.inputs,
isBedrock,
provider,
)
end_time = datetime.datetime.now()

Expand Down
13 changes: 11 additions & 2 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@
)
from weave.trace_server.secret_fetcher_context import _secret_fetcher_context

NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")


def lite_llm_completion(
api_key: str,
inputs: tsi.CompletionsCreateRequestInputs,
isBedrock: bool,
provider: Optional[str] = None,
) -> tsi.CompletionsCreateRes:
aws_access_key_id, aws_secret_access_key, aws_region_name = None, None, None
if isBedrock:
if provider == "bedrock" or provider == "bedrock_converse":
aws_access_key_id, aws_secret_access_key, aws_region_name = (
get_bedrock_credentials(inputs.model)
)
# Nova models need the region in the model name
if any(x in inputs.model for x in NOVA_MODELS) and aws_region_name:
aws_inference_region = aws_region_name.split("-")[0]
inputs.model = "bedrock/" + aws_inference_region + "." + inputs.model
# XAI models don't support response_format
elif provider == "xai":
inputs.response_format = None

import litellm

Expand Down
Loading
Loading