diff --git a/tests/test_costs.py b/tests/test_costs.py index ece4b29..d27fee1 100644 --- a/tests/test_costs.py +++ b/tests/test_costs.py @@ -44,6 +44,7 @@ ("gpt-4-32k-0314", 15), ("gpt-4-1106-preview", 15), ("gpt-4-vision-preview", 15), + ("gpt-4o", 15), ], ) def test_count_message_tokens(model, expected_output): @@ -69,6 +70,7 @@ def test_count_message_tokens(model, expected_output): ("gpt-4-32k-0314", 17), ("gpt-4-1106-preview", 17), ("gpt-4-vision-preview", 17), + ("gpt-4o", 17), ], ) def test_count_message_tokens_with_name(model, expected_output): @@ -108,6 +110,7 @@ def test_count_message_tokens_invalid_model(): ("gpt-4-1106-preview", 4), ("gpt-4-vision-preview", 4), ("text-embedding-ada-002", 4), + ("gpt-4o", 4) ], ) def test_count_string_tokens(model, expected_output): diff --git a/tokencost/costs.py b/tokencost/costs.py index 0b098ef..feb931a 100644 --- a/tokencost/costs.py +++ b/tokencost/costs.py @@ -51,6 +51,10 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int: "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4o", + "gpt-4o-2024-05-13", }: tokens_per_message = 3 tokens_per_name = 1 @@ -63,6 +67,10 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int: "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613." ) return count_message_tokens(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4o" in model: + print( + "Warning: gpt-4o may update over time. Returning num tokens assuming gpt-4o-2024-05-13.") + return count_message_tokens(messages, model="gpt-4o-2024-05-13") elif "gpt-4" in model: logger.warning( "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."