Skip to content

Commit

Permalink
use tmp as store path for model providers file
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleGoyette committed Oct 30, 2024
1 parent 84b70aa commit e1d0c0d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ gha-creds-*.json
.coverage
.nox
*.log
*/file::memory:?cache=shared
weave/trace_server/model_providers
*/file::memory:?cache=shared
14 changes: 7 additions & 7 deletions weave/trace_server/model_providers/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import requests

model_providers_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
MODEL_PROVIDERS_FILE = "model_providers.json"
MODEL_PROVIDERS_FILE = "/tmp/model_providers/model_providers.json"


PROVIDER_TO_API_KEY_NAME_MAP = {
"anthropic": "ANTHROPIC_API_KEY",
Expand All @@ -22,11 +23,10 @@ class LLMModelProviderInfo(TypedDict):


def fetch_model_to_provider_info_map(
cached_file_name: str = MODEL_PROVIDERS_FILE,
cached_file_path: str = MODEL_PROVIDERS_FILE,
) -> Dict[str, LLMModelProviderInfo]:
full_path = os.path.join(os.path.dirname(__file__), cached_file_name)
if os.path.exists(full_path):
with open(full_path, "r") as f:
if os.path.exists(cached_file_path):
with open(cached_file_path, "r") as f:
return json.load(f)
try:
req = requests.get(model_providers_url)
Expand All @@ -43,7 +43,7 @@ def fetch_model_to_provider_info_map(
providers[k] = LLMModelProviderInfo(
litellm_provider=provider, api_key_name=api_key_name
)

with open(full_path, "w") as f:
os.makedirs(os.path.dirname(cached_file_path), exist_ok=True)
with open(cached_file_path, "w") as f:
json.dump(providers, f)
return providers

0 comments on commit e1d0c0d

Please sign in to comment.