diff --git a/pyproject.toml b/pyproject.toml
index 0c3405b..50e0f84 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ tokencost = ["model_prices.json"]
 
 [project]
 name = "tokencost"
-version = "0.1.12"
+version = "0.1.13"
 authors = [
     { name = "Trisha Pan", email = "trishaepan@gmail.com" },
     { name = "Alex Reibman", email = "areibman@gmail.com" },
diff --git a/tests/test_costs.py b/tests/test_costs.py
index d27fee1..8b3f5f6 100644
--- a/tests/test_costs.py
+++ b/tests/test_costs.py
@@ -45,6 +45,8 @@
         ("gpt-4-1106-preview", 15),
         ("gpt-4-vision-preview", 15),
         ("gpt-4o", 15),
+        ("azure/gpt-4o", 15),
+        ("claude-2.1", 4),
     ],
 )
 def test_count_message_tokens(model, expected_output):
@@ -71,6 +73,9 @@ def test_count_message_tokens(model, expected_output):
         ("gpt-4-1106-preview", 17),
         ("gpt-4-vision-preview", 17),
         ("gpt-4o", 17),
+        ("azure/gpt-4o", 17),
+        ("claude-2.1", 4),
+
     ],
 )
 def test_count_message_tokens_with_name(model, expected_output):
@@ -110,7 +115,8 @@ def test_count_message_tokens_invalid_model():
         ("gpt-4-1106-preview", 4),
         ("gpt-4-vision-preview", 4),
         ("text-embedding-ada-002", 4),
-        ("gpt-4o", 4)
+        ("gpt-4o", 4),
+        ("claude-2.1", 4)
     ],
 )
 def test_count_string_tokens(model, expected_output):
@@ -148,6 +154,9 @@ def test_count_string_invalid_model():
         (MESSAGES, "gpt-4-0613", Decimal("0.00045")),
         (MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
         (MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
+        (MESSAGES, "gpt-4o", Decimal("0.000075")),
+        (MESSAGES, "azure/gpt-4o", Decimal("0.000075")),
+        (MESSAGES, "claude-2.1", Decimal("0.000032")),
         (STRING, "text-embedding-ada-002", Decimal("0.0000004")),
     ],
 )
@@ -182,6 +191,9 @@ def test_invalid_prompt_format():
         (STRING, "gpt-4-0613", Decimal("0.00024")),
         (STRING, "gpt-4-1106-preview", Decimal("0.00012")),
         (STRING, "gpt-4-vision-preview", Decimal("0.00012")),
+        (STRING, "gpt-4o", Decimal("0.000060")),
+        (STRING, "azure/gpt-4o", Decimal("0.000060")),
+        (STRING, "claude-2.1", Decimal("0.000096")),
         (STRING, "text-embedding-ada-002", 0),
     ],
 )
diff --git a/tokencost/costs.py b/tokencost/costs.py
index ce29d0d..3677d6f 100644
--- a/tokencost/costs.py
+++ b/tokencost/costs.py
@@ -41,7 +41,7 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
     """
     model = model.lower()
     model = strip_ft_model_name(model)
-    
+
     if "claude-" in model:
         """
         Note that this is only accurate for older models, e.g. `claude-2.1`.
@@ -49,8 +49,8 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
         instead you should rely on the `usage` property in the response for exact counts.
         """
         prompt = "".join(message["content"] for message in messages)
-        return count_string_tokens(prompt,model)
-
+        return count_string_tokens(prompt, model)
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -116,12 +116,20 @@ def count_string_tokens(prompt: str, model: str) -> int:
         int: The number of tokens in the text string.
     """
     model = model.lower()
+
+    if "/" in model:
+        model = model.split("/")[-1]
+
     if "claude-" in model:
         """
         Note that this is only accurate for older models, e.g. `claude-2.1`.
         For newer models this can only be used as a _very_ rough estimate,
         instead you should rely on the `usage` property in the response for exact counts.
         """
+        if "claude-3" in model:
+            logger.warning(
+                "Warning: Claude-3 models are not yet supported. Returning num tokens assuming claude-2.1."
+            )
         client = anthropic.Client()
         token_count = client.count_tokens(prompt)
         return token_count
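
A quick sanity-check sketch of the two behavior changes in `count_string_tokens` (provider-prefix stripping and the Claude token-counting path). This assumes the functions are importable from the top-level `tokencost` package and that an `ANTHROPIC_API_KEY` is available for `anthropic.Client()` construction; it illustrates the patch rather than being part of it.

```python
# Minimal sketch against tokencost 0.1.13 (package-level imports assumed).
from tokencost import count_message_tokens, count_string_tokens

# 1. Provider prefixes such as "azure/" are now stripped before tokenizer
#    lookup, so a prefixed model resolves exactly like its bare name.
prefixed = count_string_tokens("Hello, world!", "azure/gpt-4o")
bare = count_string_tokens("Hello, world!", "gpt-4o")
assert prefixed == bare

# 2. Claude models bypass tiktoken: count_message_tokens joins the message
#    contents and defers to count_string_tokens, which uses the Anthropic
#    client's tokenizer (accurate for claude-2.1; claude-3 logs a warning
#    and reuses the claude-2.1 count as a rough estimate).
messages = [{"role": "user", "content": "Hello, world!"}]
print(count_message_tokens(messages, "claude-2.1"))
```

One caveat worth flagging in review: `anthropic.Client().count_tokens(...)` is the synchronous helper from the `anthropic` SDK versions current at the time of this patch; later SDK releases removed it in favor of a messages-based token-counting endpoint, so the dependency pin matters here.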
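
The new expected `Decimal`s in the cost tests follow directly from per-token prices. The arithmetic below reconstructs them: the token counts come from the counting tests in this same diff, while the quoted rates are the published gpt-4o and claude-2.1 prices at the time of this change, assumed here rather than re-read from `model_prices.json`.

```python
from decimal import Decimal

# Token counts established by the counting tests above.
PROMPT_TOKENS = 15      # count_message_tokens(MESSAGES, "gpt-4o")
COMPLETION_TOKENS = 4   # count_string_tokens(STRING, "gpt-4o")
CLAUDE_TOKENS = 4       # claude-2.1 counts 4 tokens for both fixtures

# gpt-4o (and azure/gpt-4o): $5 per 1M input tokens, $15 per 1M output tokens.
assert PROMPT_TOKENS * Decimal("0.000005") == Decimal("0.000075")
assert COMPLETION_TOKENS * Decimal("0.000015") == Decimal("0.000060")

# claude-2.1: $8 per 1M input tokens, $24 per 1M output tokens.
assert CLAUDE_TOKENS * Decimal("0.000008") == Decimal("0.000032")
assert CLAUDE_TOKENS * Decimal("0.000024") == Decimal("0.000096")
```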