From 68d83c5d6620dcc36e33292e153ab973438e5009 Mon Sep 17 00:00:00 2001 From: Reinder Vos de Wael Date: Fri, 10 Nov 2023 15:39:40 -0500 Subject: [PATCH] Add synonym/jeopardy/antonym endpoints (#6) * Add synonym/jeopardy/antonym endpoints * online postgres debug * online postgres debug * Fix faulty ID * Refactor text task endpoints and schemas --- poetry.lock | 20 ++++- pyproject.toml | 1 + src/linguaweb_api/microservices/openai.py | 24 ++++- src/linguaweb_api/microservices/sql.py | 2 +- src/linguaweb_api/routers/text/controller.py | 48 +++++----- src/linguaweb_api/routers/text/models.py | 48 ++++++++-- src/linguaweb_api/routers/text/schemas.py | 21 +++++ src/linguaweb_api/routers/text/views.py | 74 +++++++++++++++- tests/endpoint/conftest.py | 3 + tests/endpoint/test_text.py | 93 +++++++++++++++----- tests/unit/test_openai.py | 8 +- 11 files changed, 282 insertions(+), 60 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9f60884..8e10704 100644 --- a/poetry.lock +++ b/poetry.lock @@ -875,6 +875,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -1220,4 +1238,4 @@ docs = [] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "98b1f38d6ad609a2a0e3dc93806f2b091560f59a26155b309b806eced69f0c47" +content-hash = "bdf1af7c2b7dae53b0ef050e5a5048cc214e79e830a215100877c9ae030e2ddb" diff --git a/pyproject.toml b/pyproject.toml index 366d7b7..106b81d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ ruff = "^0.1.4" httpx = "^0.25.1" pytest-mock = "^3.12.0" pytest-dotenv = "^0.5.2" +pytest-asyncio = "^0.21.1" [tool.poetry.group.docs.dependencies] pdoc = "^14.1.0" diff --git a/src/linguaweb_api/microservices/openai.py b/src/linguaweb_api/microservices/openai.py index 6bbeb60..f673f63 100644 --- a/src/linguaweb_api/microservices/openai.py +++ b/src/linguaweb_api/microservices/openai.py @@ -1,4 +1,5 @@ """This module contains interactions with OpenAI models.""" +import enum import logging from typing import Literal, TypedDict @@ -13,6 +14,27 @@ logger = logging.getLogger(LOGGER_NAME) +class Prompts(str, enum.Enum): + """A class representing the prompts for the GPT model.""" + + WORD_DESCRIPTION = ( + "Return a brief definition for the word provided by the user without using the " + "word (or number, if relevant) in the definition." + ) + WORD_SYNONYMS = ( + "List synonyms for the following word without using the word (or " + "number, if relevant) at all as a comma separated list" + ) + WORD_ANTONYMS = ( + "List antonyms for the following word without using the word (or number, if " + "relevant) at all as a comma separated list" + ) + WORD_JEOPARDY = ( + "Return a very brief Jeopardy!-style description related to the following word " + "without using the word (or number, if relevant) at all" + ) + + class Message(TypedDict): """A message object.""" @@ -38,7 +60,7 @@ def __init__(self, model: str = "gpt-4-1106-preview") -> None: self.model = model self.client = openai.OpenAI(api_key=OPENAI_API_KEY.get_secret_value()) - def run(self, *, prompt: str, system_prompt: str) -> str: + async def run(self, *, prompt: str, system_prompt: str) -> str: """Runs the GPT model. Args: diff --git a/src/linguaweb_api/microservices/sql.py b/src/linguaweb_api/microservices/sql.py index bd10126..ddc07c3 100644 --- a/src/linguaweb_api/microservices/sql.py +++ b/src/linguaweb_api/microservices/sql.py @@ -32,7 +32,7 @@ def __init__(self) -> None: """ logger.debug("Initializing database.") db_url = self.get_db_url() - engine_args: dict[str, Any] = {"echo": True} + engine_args: dict[str, Any] = {} if ENVIRONMENT == "development": engine_args["connect_args"] = {"check_same_thread": False} engine_args["poolclass"] = pool.StaticPool diff --git a/src/linguaweb_api/routers/text/controller.py b/src/linguaweb_api/routers/text/controller.py index d134a13..ed7031d 100644 --- a/src/linguaweb_api/routers/text/controller.py +++ b/src/linguaweb_api/routers/text/controller.py @@ -1,4 +1,5 @@ """Business logic for the text router.""" +import asyncio import logging import fastapi @@ -15,15 +16,11 @@ logger = logging.getLogger(LOGGER_NAME) -async def get_word_description( - session: orm.Session, - gpt: openai.GPT, -) -> models.TextTask: +async def get_text_task(session: orm.Session) -> models.TextTask: """Returns the description of a random word. Args: session: The database session. - gpt: The GPT model to use. Returns: The description of the word. @@ -32,25 +29,32 @@ async def get_word_description( logger.debug("Checking description.") word = dictionary.get_random_word() - word_database = session.query(models.TextTask).filter_by(word=word).first() - if not word_database: - logger.debug("Word not found in database.") - word_database = models.TextTask(word=word) - session.add(word_database) - - if word_database.description: - logger.debug("Word description found in database.") - return word_database - - logger.debug("Word description not found in database.") - system_prompt = ( - "Return a brief definition for the word provided by the user without using the " - "word (or number, if relevant) in the definition." + if text_task := session.query(models.TextTask).filter_by(word=word).first(): + logger.debug("Text task already exists in database.") + return text_task + + logger.debug("Running GPT.") + gpt = openai.GPT() + gpt_calls = [ + gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_DESCRIPTION), + gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_SYNONYMS), + gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_ANTONYMS), + gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_JEOPARDY), + ] + results = await asyncio.gather(*gpt_calls) + + logger.debug("Creating new text task.") + new_text_task = models.TextTask( + word=word, + description=results[0], + synonyms=results[1], + antonyms=results[2], + jeopardy=results[3], ) - description = gpt.run(prompt=word, system_prompt=system_prompt) - word_database.description = description + session.add(new_text_task) session.commit() - return word_database + logger.debug(new_text_task.id) + return new_text_task async def check_word( diff --git a/src/linguaweb_api/routers/text/models.py b/src/linguaweb_api/routers/text/models.py index 35eba5a..e803b9f 100644 --- a/src/linguaweb_api/routers/text/models.py +++ b/src/linguaweb_api/routers/text/models.py @@ -1,17 +1,55 @@ """Models for the text router.""" +from typing import Any + import sqlalchemy -from sqlalchemy import orm +from sqlalchemy import orm, types from linguaweb_api.core import models +class CommaSeparatedList(types.TypeDecorator): + """A custom SQLAlchemy for comma separated lists.""" + + impl = sqlalchemy.String(1024) + + def process_bind_param(self, value: Any | None, dialect: Any) -> str | None: # noqa: ANN401, ARG002 + """Converts a list of strings to a comma separated string. + + Args: + value: The list of strings or a comma separated string. + dialect: The dialect. + + Returns: + str | None: The comma separated string. + """ + if value is None: + return None + if isinstance(value, list): + return ",".join(value) + return value + + def process_result_value(self, value: Any | None, dialect: Any) -> list[str] | None: # noqa: ARG002, ANN401 + """Converts a comma separated string to a list of strings. + + Args: + value: The comma separated string. + dialect: The dialect. + + Returns: + list[str]: The list of strings. + """ + if value is None: + return None + return value.split(",") + + class TextTask(models.BaseTable): """Table for text tasks.""" __tablename__ = "text_tasks" word: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(64), unique=True) - description: orm.Mapped[str] = orm.mapped_column( - sqlalchemy.String(1024), - nullable=True, - ) + description: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(1024)) + synonyms: orm.Mapped[str] = orm.mapped_column(CommaSeparatedList) + antonyms: orm.Mapped[str] = orm.mapped_column(CommaSeparatedList) + jeopardy: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(1024)) diff --git a/src/linguaweb_api/routers/text/schemas.py b/src/linguaweb_api/routers/text/schemas.py index 333bf1c..bca0217 100644 --- a/src/linguaweb_api/routers/text/schemas.py +++ b/src/linguaweb_api/routers/text/schemas.py @@ -9,6 +9,27 @@ class WordDescription(pydantic.BaseModel): description: str +class WordSynonyms(pydantic.BaseModel): + """Word synonym task.""" + + id: int + synonyms: list[str] + + +class WordAntonyms(pydantic.BaseModel): + """Word antonym task.""" + + id: int + antonyms: list[str] + + +class WordJeopardy(pydantic.BaseModel): + """Word jeopardy task.""" + + id: int + jeopardy: str + + class WordCheck(pydantic.BaseModel): """Schema for checking information about a word.""" diff --git a/src/linguaweb_api/routers/text/views.py b/src/linguaweb_api/routers/text/views.py index 774278c..2fbb058 100644 --- a/src/linguaweb_api/routers/text/views.py +++ b/src/linguaweb_api/routers/text/views.py @@ -6,7 +6,7 @@ from sqlalchemy import orm from linguaweb_api.core import config -from linguaweb_api.microservices import openai, sql +from linguaweb_api.microservices import sql from linguaweb_api.routers.text import controller, schemas settings = config.get_settings() @@ -26,7 +26,6 @@ ) async def get_word_description( session: orm.Session = fastapi.Depends(sql.get_session), - gpt: openai.GPT = fastapi.Depends(openai.GPT), ) -> schemas.WordDescription: """Returns the description of a random word. @@ -35,7 +34,72 @@ async def get_word_description( gpt: The GPT model to use. """ logger.debug("Getting word description.") - return await controller.get_word_description(session, gpt) + text_task = await controller.get_text_task(session) + logger.debug("Got word description.") + return text_task + + +@router.get( + "/synonyms", + response_model=schemas.WordSynonyms, + status_code=status.HTTP_200_OK, + summary="Returns synonyms of a random word.", + description="Returns synonyms of a random word.", +) +async def get_word_synonym( + session: orm.Session = fastapi.Depends(sql.get_session), +) -> schemas.WordSynonyms: + """Returns synonyms of a random word. + + Args: + session: The database session. + """ + logger.debug("Getting word synonym.") + text_task = await controller.get_text_task(session) + logger.debug("Got word synonym.") + return text_task + + +@router.get( + "/antonyms", + response_model=schemas.WordAntonyms, + status_code=status.HTTP_200_OK, + summary="Returns antonyms of a random word.", + description="Returns antonyms of a random word.", +) +async def get_word_antonym( + session: orm.Session = fastapi.Depends(sql.get_session), +) -> schemas.WordAntonyms: + """Returns antonyms of a random word. + + Args: + session: The database session. + """ + logger.debug("Getting word antonym.") + text_task = await controller.get_text_task(session) + logger.debug("Got word antonym.") + return text_task + + +@router.get( + "/jeopardy", + response_model=schemas.WordJeopardy, + status_code=status.HTTP_200_OK, + summary="Returns a jeopardy question of a random word.", + description="Returns a jeopardy question of a random word.", +) +async def get_word_jeopardy( + session: orm.Session = fastapi.Depends(sql.get_session), +) -> schemas.WordJeopardy: + """Returns a jeopardy question of a random word. + + Args: + session: The database session. + """ + logger.debug("Getting word jeopardy.") + text_task = await controller.get_text_task(session) + logger.debug("Got word jeopardy.") + return text_task @router.post( @@ -67,4 +131,6 @@ async def check_word( session: The database session. """ logger.debug("Checking word.") - return await controller.check_word(word_id, checks, session) + is_correct = await controller.check_word(word_id, checks, session) + logger.debug("Checked word.") + return is_correct diff --git a/tests/endpoint/conftest.py b/tests/endpoint/conftest.py index 0ca973c..dd0ca01 100644 --- a/tests/endpoint/conftest.py +++ b/tests/endpoint/conftest.py @@ -16,6 +16,9 @@ class Endpoints(str, enum.Enum): """Enum class that represents the available endpoints for the API.""" GET_DESCRIPTION = f"{API_ROOT}/text/description" + GET_SYNONYMS = f"{API_ROOT}/text/synonyms" + GET_ANTONYMS = f"{API_ROOT}/text/antonyms" + GET_JEOPARDY = f"{API_ROOT}/text/jeopardy" POST_CHECK_WORD = f"{API_ROOT}/text/check/{{word_id}}" GET_HEALTH = f"{API_ROOT}/health" diff --git a/tests/endpoint/test_text.py b/tests/endpoint/test_text.py index aa52101..ebe460e 100644 --- a/tests/endpoint/test_text.py +++ b/tests/endpoint/test_text.py @@ -1,4 +1,5 @@ """Tests for the text endpoints.""" +import pytest import pytest_mock from fastapi import status, testclient from sqlalchemy import orm @@ -7,56 +8,100 @@ from tests.endpoint import conftest -def test_get_description_entry_does_not_exist( +@pytest.fixture() +def text_task(session: orm.Session) -> text_model.TextTask: + """Inserts a text task into the database. + + Args: + session: The database session. + """ + word = text_model.TextTask( + word="test", + description="mock_description", + synonyms="mock_synonyms", + antonyms="mock_antonyms", + jeopardy="mock_jeopardy", + ) + session.add(word) + session.commit() + return word + + +@pytest.mark.parametrize( + "endpoint_type", + [ + "description", + "synonyms", + "antonyms", + "jeopardy", + ], +) +def test_get_text_entry_exist( mocker: pytest_mock.MockFixture, + endpoint_type: str, + text_task: text_model.TextTask, client: testclient.TestClient, endpoints: conftest.Endpoints, ) -> None: - """Tests the get health endpoint.""" + """Tests the get text task endpoints with an existing file.""" mocker.patch( - "linguaweb_api.microservices.openai.GPT.run", - return_value="mock_description", + "linguaweb_api.core.dictionary.get_random_word", + return_value=text_task.word, ) + endpoint = getattr(endpoints, f"GET_{endpoint_type.upper()}") - response = client.get(endpoints.GET_DESCRIPTION) + response = client.get(endpoint) assert response.status_code == status.HTTP_200_OK - assert response.json()["description"] == "mock_description" - assert isinstance(response.json()["id"], int) - - -def test_get_description_entry_exist_without_description( + assert response.json()["id"] == text_task.id + if endpoint_type in ("synonyms", "antonyms"): + assert response.json()[endpoint_type][0] == f"mock_{endpoint_type}" + else: + assert response.json()[endpoint_type] == f"mock_{endpoint_type}" + + +@pytest.mark.parametrize( + "endpoint_type", + [ + "description", + "synonyms", + "antonyms", + "jeopardy", + ], +) +def test_get_text_entry_does_not_exist( mocker: pytest_mock.MockFixture, + endpoint_type: str, client: testclient.TestClient, endpoints: conftest.Endpoints, ) -> None: - """Tests the get health endpoint.""" - mocker.patch( + """Tests the get text task endpoints with no existing file.""" + mock_gpt = mocker.patch( "linguaweb_api.microservices.openai.GPT.run", - return_value="mock_description", + return_value="mock", ) + expected_n_gpt_calls = 4 + endpoint = getattr(endpoints, f"GET_{endpoint_type.upper()}") - response = client.get(endpoints.GET_DESCRIPTION) + response = client.get(endpoint) + assert mock_gpt.call_count == expected_n_gpt_calls assert response.status_code == status.HTTP_200_OK - assert response.json()["description"] == "mock_description" assert isinstance(response.json()["id"], int) + if endpoint_type in ("synonyms", "antonyms"): + assert isinstance(response.json()[endpoint_type], list) + else: + assert isinstance(response.json()[endpoint_type], str) -def test_check_word( - mocker: pytest_mock.MockFixture, +def test_check_word_exists( client: testclient.TestClient, endpoints: conftest.Endpoints, - session: orm.Session, + text_task: text_model.TextTask, ) -> None: """Tests the check word description endpoint.""" - word = text_model.TextTask(word="test", description="mock_description") - session.add(word) - session.flush() - session.commit() - response = client.post( - endpoints.POST_CHECK_WORD.format(word_id=word.id), + endpoints.POST_CHECK_WORD.format(word_id=text_task.id), json={"word": "test", "description": "mock_description"}, ) diff --git a/tests/unit/test_openai.py b/tests/unit/test_openai.py index 6535c48..68f5a5e 100644 --- a/tests/unit/test_openai.py +++ b/tests/unit/test_openai.py @@ -29,7 +29,8 @@ def gpt_instance( return openai.GPT() -def test_gpt_run_method( +@pytest.mark.asyncio() +async def test_gpt_run_method( gpt_instance: openai.GPT, mocker: pytest_mock.MockerFixture, ) -> None: @@ -45,7 +46,10 @@ def test_gpt_run_method( }, ) - actual_response = gpt_instance.run(prompt=user_prompt, system_prompt=system_prompt) + actual_response = await gpt_instance.run( + prompt=user_prompt, + system_prompt=system_prompt, + ) gpt_instance.client.chat.completions.create.assert_called_once_with( # type: ignore[attr-defined] model=gpt_instance.model,