Skip to content

Commit

Permalink
xai llm
Browse files Browse the repository at this point in the history
  • Loading branch information
jwlee64 committed Dec 13, 2024
1 parent 0525621 commit 2b083ac
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ export const LLM_MAX_TOKENS = {
max_tokens: 4096,
supports_function_calling: true,
},

'xai/grok-beta': {
max_tokens: 131072,
provider: 'xai',
supports_function_calling: true,
},
};

export type LLMMaxTokensKey = keyof typeof LLM_MAX_TOKENS;
Expand All @@ -367,6 +373,7 @@ export const LLM_PROVIDERS = [
'gemini',
'groq',
'bedrock',
'xai',
];

export const LLM_PROVIDER_LABELS: Record<
Expand All @@ -378,4 +385,5 @@ export const LLM_PROVIDER_LABELS: Record<
gemini: 'Gemini',
groq: 'Groq',
bedrock: 'Bedrock',
xai: 'xAI',
};
6 changes: 3 additions & 3 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,8 +1447,8 @@ def completions_create(
if not secret_name:
raise InvalidRequest(f"No secret name found for model {model_name}")
api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
isBedrock = model_info.get("litellm_provider") == "bedrock"
if not api_key and not isBedrock:
provider = model_info.get("litellm_provider")
if not api_key and provider != "bedrock":
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for model {model_name}",
api_key_name=secret_name,
Expand All @@ -1458,7 +1458,7 @@ def completions_create(
res = lite_llm_completion(
api_key,
req.inputs,
isBedrock,
provider,
)
end_time = datetime.datetime.now()

Expand Down
6 changes: 4 additions & 2 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
def lite_llm_completion(
api_key: str,
inputs: tsi.CompletionsCreateRequestInputs,
isBedrock: bool,
provider: str | None = None,
) -> tsi.CompletionsCreateRes:
aws_access_key_id, aws_secret_access_key, aws_region_name = None, None, None
if isBedrock:
if provider == "bedrock":
aws_access_key_id, aws_secret_access_key, aws_region_name = (
get_bedrock_credentials(inputs.model)
)
elif provider == "xai":
inputs.response_format = None

import litellm

Expand Down
2 changes: 1 addition & 1 deletion weave/trace_server/model_providers/model_providers.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions weave/trace_server/model_providers/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"fireworks": "FIREWORKS_API_KEY",
"groq": "GEMMA_API_KEY",
"bedrock": "BEDROCK_API_KEY",
"xai": "XAI_API_KEY",
}


Expand Down

0 comments on commit 2b083ac

Please sign in to comment.