From e7bcd17c68b0d2b1b67e8f453b8cf0e676d92188 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 22 Apr 2023 10:35:50 -0400 Subject: [PATCH] Fix whitespace issue --- src/marvin/ai_functions/strings.py | 26 ++++++++++++++++++-- tests/llm_tests/ai_functions/test_strings.py | 6 ++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/marvin/ai_functions/strings.py b/src/marvin/ai_functions/strings.py index f6b7f514f..82e68f7c5 100644 --- a/src/marvin/ai_functions/strings.py +++ b/src/marvin/ai_functions/strings.py @@ -1,14 +1,36 @@ +import asyncio +import functools +import inspect + from marvin.ai_functions import ai_fn +def _strip_result(fn): + """ + A decorator that automatically strips whitespace from the result of + calling the function + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + result = fn(*args, **kwargs) + if inspect.iscoroutine(result): + result = asyncio.run(result) + return result.strip() + + return wrapper + + +@_strip_result @ai_fn def fix_capitalization(text: str) -> str: """ - Given `text`, fix any capitalization errors. Do not change the text in any - other way. + Given `text`, which represents complete sentences, fix any capitalization + errors. """ +@_strip_result @ai_fn def title_case(text: str) -> str: """ diff --git a/tests/llm_tests/ai_functions/test_strings.py b/tests/llm_tests/ai_functions/test_strings.py index a080bcc89..ef8dd3471 100644 --- a/tests/llm_tests/ai_functions/test_strings.py +++ b/tests/llm_tests/ai_functions/test_strings.py @@ -2,7 +2,7 @@ class TestFixCapitalization: - def test_fix_capitalization(self): + def test_fix_capitalization(self, gpt_4): result = string_fns.fix_capitalization("the european went over to canada, eh?") assert result == "The European went over to Canada, eh?" @@ -13,8 +13,6 @@ def test_title_case(self): assert result == "The European Went Over to Canada, Eh?" def test_short_prepositions_not_capitalized(self): - result = string_fns.title_case( - input="let me go to the store", - ) + result = string_fns.title_case("let me go to the store") assert result == "Let Me Go to the Store"