From e1d0c0d894d9f76306f59c6d41a954aebbbb2832 Mon Sep 17 00:00:00 2001 From: Kyle Goyette Date: Wed, 30 Oct 2024 12:39:45 -0700 Subject: [PATCH] use tmp as store path for model providers file --- .gitignore | 3 +-- .../model_providers/model_providers.py | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 0f6040766e8..49e6ccc33e0 100644 --- a/.gitignore +++ b/.gitignore @@ -17,5 +17,4 @@ gha-creds-*.json .coverage .nox *.log -*/file::memory:?cache=shared -weave/trace_server/model_providers \ No newline at end of file +*/file::memory:?cache=shared \ No newline at end of file diff --git a/weave/trace_server/model_providers/model_providers.py b/weave/trace_server/model_providers/model_providers.py index f33f18ad224..3e7e8d01d76 100644 --- a/weave/trace_server/model_providers/model_providers.py +++ b/weave/trace_server/model_providers/model_providers.py @@ -5,7 +5,8 @@ import requests model_providers_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" -MODEL_PROVIDERS_FILE = "model_providers.json" +MODEL_PROVIDERS_FILE = "/tmp/model_providers/model_providers.json" + PROVIDER_TO_API_KEY_NAME_MAP = { "anthropic": "ANTHROPIC_API_KEY", @@ -22,11 +23,10 @@ class LLMModelProviderInfo(TypedDict): def fetch_model_to_provider_info_map( - cached_file_name: str = MODEL_PROVIDERS_FILE, + cached_file_path: str = MODEL_PROVIDERS_FILE, ) -> Dict[str, LLMModelProviderInfo]: - full_path = os.path.join(os.path.dirname(__file__), cached_file_name) - if os.path.exists(full_path): - with open(full_path, "r") as f: + if os.path.exists(cached_file_path): + with open(cached_file_path, "r") as f: return json.load(f) try: req = requests.get(model_providers_url) @@ -43,7 +43,7 @@ def fetch_model_to_provider_info_map( providers[k] = LLMModelProviderInfo( litellm_provider=provider, api_key_name=api_key_name ) - - with open(full_path, "w") as f: + os.makedirs(os.path.dirname(cached_file_path), exist_ok=True) + with open(cached_file_path, "w") as f: json.dump(providers, f) return providers