chore(weave): add two providers to backend (#3242)
* add two providers to backend

* lint

* pr comments

* use optional
jwlee64 authored Dec 14, 2024
1 parent c964e3f commit c0b90b4
Showing 4 changed files with 17 additions and 6 deletions.
weave/trace_server/clickhouse_trace_server_batched.py (3 additions, 3 deletions)

@@ -1447,8 +1447,8 @@ def completions_create(
        if not secret_name:
            raise InvalidRequest(f"No secret name found for model {model_name}")
        api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
-       isBedrock = model_info.get("litellm_provider") == "bedrock"
-       if not api_key and not isBedrock:
+       provider = model_info.get("litellm_provider")
+       if not api_key and provider != "bedrock" and provider != "bedrock_converse":
            raise MissingLLMApiKeyError(
                f"No API key {secret_name} found for model {model_name}",
                api_key_name=secret_name,
@@ -1458,7 +1458,7 @@ def completions_create(
        res = lite_llm_completion(
            api_key,
            req.inputs,
-           isBedrock,
+           provider,
        )
        end_time = datetime.datetime.now()

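The effect of the change above is that a missing provider API key is only fatal for non-Bedrock providers; both "bedrock" and "bedrock_converse" authenticate with AWS credentials instead. A minimal sketch of that check as a standalone helper (requires_api_key is a hypothetical name, not part of the Weave code):

from typing import Optional

def requires_api_key(provider: Optional[str]) -> bool:
    # Bedrock providers authenticate with AWS credentials fetched via
    # get_bedrock_credentials, so no provider API key is required for them.
    return provider not in ("bedrock", "bedrock_converse")

# The guard in completions_create then reads roughly:
#     if not api_key and requires_api_key(provider):
#         raise MissingLLMApiKeyError(...)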
weave/trace_server/llm_completion.py (11 additions, 2 deletions)

@@ -7,17 +7,26 @@
)
from weave.trace_server.secret_fetcher_context import _secret_fetcher_context

+NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")
+

def lite_llm_completion(
    api_key: str,
    inputs: tsi.CompletionsCreateRequestInputs,
-   isBedrock: bool,
+   provider: Optional[str] = None,
) -> tsi.CompletionsCreateRes:
    aws_access_key_id, aws_secret_access_key, aws_region_name = None, None, None
-   if isBedrock:
+   if provider == "bedrock" or provider == "bedrock_converse":
        aws_access_key_id, aws_secret_access_key, aws_region_name = (
            get_bedrock_credentials(inputs.model)
        )
+       # Nova models need the region in the model name
+       if any(x in inputs.model for x in NOVA_MODELS) and aws_region_name:
+           aws_inference_region = aws_region_name.split("-")[0]
+           inputs.model = "bedrock/" + aws_inference_region + "." + inputs.model
+   # XAI models don't support response_format
+   elif provider == "xai":
+       inputs.response_format = None

    import litellm

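The Nova branch above rewrites the model id so that the first segment of the AWS region becomes a prefix on the Bedrock model name, matching Bedrock's cross-region inference profile naming. A rough standalone illustration of just that string manipulation (add_nova_region_prefix is a hypothetical helper, and the model id and region values are examples, not taken from the diff):

NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")

def add_nova_region_prefix(model: str, aws_region_name: str) -> str:
    # Mirror the diff: only Nova models get the region prefix, and only
    # when a region is known.
    if any(x in model for x in NOVA_MODELS) and aws_region_name:
        aws_inference_region = aws_region_name.split("-")[0]  # "us-east-1" -> "us"
        return "bedrock/" + aws_inference_region + "." + model
    return model

# Example:
#     add_nova_region_prefix("amazon.nova-lite-v1:0", "us-east-1")
# returns "bedrock/us.amazon.nova-lite-v1:0".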
(Diffs for the remaining 2 changed files are not shown.)
