Skip to content
This repository has been archived by the owner on Dec 31, 2024. It is now read-only.

Commit

Permalink
Add synonym/jeopardy/antonym endpoints (#6)
Browse files Browse the repository at this point in the history
* Add synonym/jeopardy/antonym endpoints

* online postgres debug

* online postgres debug

* Fix faulty ID

* Refactor text task endpoints and schemas
  • Loading branch information
ReinderVosDeWael authored Nov 10, 2023
1 parent 79f13c8 commit 68d83c5
Show file tree
Hide file tree
Showing 11 changed files with 282 additions and 60 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ ruff = "^0.1.4"
httpx = "^0.25.1"
pytest-mock = "^3.12.0"
pytest-dotenv = "^0.5.2"
pytest-asyncio = "^0.21.1"

[tool.poetry.group.docs.dependencies]
pdoc = "^14.1.0"
Expand Down
24 changes: 23 additions & 1 deletion src/linguaweb_api/microservices/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module contains interactions with OpenAI models."""
import enum
import logging
from typing import Literal, TypedDict

Expand All @@ -13,6 +14,27 @@
logger = logging.getLogger(LOGGER_NAME)


class Prompts(str, enum.Enum):
"""A class representing the prompts for the GPT model."""

WORD_DESCRIPTION = (
"Return a brief definition for the word provided by the user without using the "
"word (or number, if relevant) in the definition."
)
WORD_SYNONYMS = (
"List synonyms for the following word without using the word (or "
"number, if relevant) at all as a comma separated list"
)
WORD_ANTONYMS = (
"List antonyms for the following word without using the word (or number, if "
"relevant) at all as a comma separated list"
)
WORD_JEOPARDY = (
"Return a very brief Jeopardy!-style description related to the following word "
"without using the word (or number, if relevant) at all"
)


class Message(TypedDict):
"""A message object."""

Expand All @@ -38,7 +60,7 @@ def __init__(self, model: str = "gpt-4-1106-preview") -> None:
self.model = model
self.client = openai.OpenAI(api_key=OPENAI_API_KEY.get_secret_value())

def run(self, *, prompt: str, system_prompt: str) -> str:
async def run(self, *, prompt: str, system_prompt: str) -> str:
"""Runs the GPT model.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/linguaweb_api/microservices/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self) -> None:
"""
logger.debug("Initializing database.")
db_url = self.get_db_url()
engine_args: dict[str, Any] = {"echo": True}
engine_args: dict[str, Any] = {}
if ENVIRONMENT == "development":
engine_args["connect_args"] = {"check_same_thread": False}
engine_args["poolclass"] = pool.StaticPool
Expand Down
48 changes: 26 additions & 22 deletions src/linguaweb_api/routers/text/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Business logic for the text router."""
import asyncio
import logging

import fastapi
Expand All @@ -15,15 +16,11 @@
logger = logging.getLogger(LOGGER_NAME)


async def get_word_description(
session: orm.Session,
gpt: openai.GPT,
) -> models.TextTask:
async def get_text_task(session: orm.Session) -> models.TextTask:
"""Returns the description of a random word.
Args:
session: The database session.
gpt: The GPT model to use.
Returns:
The description of the word.
Expand All @@ -32,25 +29,32 @@ async def get_word_description(
logger.debug("Checking description.")
word = dictionary.get_random_word()

word_database = session.query(models.TextTask).filter_by(word=word).first()
if not word_database:
logger.debug("Word not found in database.")
word_database = models.TextTask(word=word)
session.add(word_database)

if word_database.description:
logger.debug("Word description found in database.")
return word_database

logger.debug("Word description not found in database.")
system_prompt = (
"Return a brief definition for the word provided by the user without using the "
"word (or number, if relevant) in the definition."
if text_task := session.query(models.TextTask).filter_by(word=word).first():
logger.debug("Text task already exists in database.")
return text_task

logger.debug("Running GPT.")
gpt = openai.GPT()
gpt_calls = [
gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_DESCRIPTION),
gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_SYNONYMS),
gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_ANTONYMS),
gpt.run(prompt=word, system_prompt=openai.Prompts.WORD_JEOPARDY),
]
results = await asyncio.gather(*gpt_calls)

logger.debug("Creating new text task.")
new_text_task = models.TextTask(
word=word,
description=results[0],
synonyms=results[1],
antonyms=results[2],
jeopardy=results[3],
)
description = gpt.run(prompt=word, system_prompt=system_prompt)
word_database.description = description
session.add(new_text_task)
session.commit()
return word_database
logger.debug(new_text_task.id)
return new_text_task


async def check_word(
Expand Down
48 changes: 43 additions & 5 deletions src/linguaweb_api/routers/text/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,55 @@
"""Models for the text router."""
from typing import Any

import sqlalchemy
from sqlalchemy import orm
from sqlalchemy import orm, types

from linguaweb_api.core import models


class CommaSeparatedList(types.TypeDecorator):
"""A custom SQLAlchemy for comma separated lists."""

impl = sqlalchemy.String(1024)

def process_bind_param(self, value: Any | None, dialect: Any) -> str | None: # noqa: ANN401, ARG002
"""Converts a list of strings to a comma separated string.
Args:
value: The list of strings or a comma separated string.
dialect: The dialect.
Returns:
str | None: The comma separated string.
"""
if value is None:
return None
if isinstance(value, list):
return ",".join(value)
return value

def process_result_value(self, value: Any | None, dialect: Any) -> list[str] | None: # noqa: ARG002, ANN401
"""Converts a comma separated string to a list of strings.
Args:
value: The comma separated string.
dialect: The dialect.
Returns:
list[str]: The list of strings.
"""
if value is None:
return None
return value.split(",")


class TextTask(models.BaseTable):
"""Table for text tasks."""

__tablename__ = "text_tasks"

word: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(64), unique=True)
description: orm.Mapped[str] = orm.mapped_column(
sqlalchemy.String(1024),
nullable=True,
)
description: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(1024))
synonyms: orm.Mapped[str] = orm.mapped_column(CommaSeparatedList)
antonyms: orm.Mapped[str] = orm.mapped_column(CommaSeparatedList)
jeopardy: orm.Mapped[str] = orm.mapped_column(sqlalchemy.String(1024))
21 changes: 21 additions & 0 deletions src/linguaweb_api/routers/text/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@ class WordDescription(pydantic.BaseModel):
description: str


class WordSynonyms(pydantic.BaseModel):
"""Word synonym task."""

id: int
synonyms: list[str]


class WordAntonyms(pydantic.BaseModel):
"""Word antonym task."""

id: int
antonyms: list[str]


class WordJeopardy(pydantic.BaseModel):
"""Word jeopardy task."""

id: int
jeopardy: str


class WordCheck(pydantic.BaseModel):
"""Schema for checking information about a word."""

Expand Down
74 changes: 70 additions & 4 deletions src/linguaweb_api/routers/text/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy import orm

from linguaweb_api.core import config
from linguaweb_api.microservices import openai, sql
from linguaweb_api.microservices import sql
from linguaweb_api.routers.text import controller, schemas

settings = config.get_settings()
Expand All @@ -26,7 +26,6 @@
)
async def get_word_description(
session: orm.Session = fastapi.Depends(sql.get_session),
gpt: openai.GPT = fastapi.Depends(openai.GPT),
) -> schemas.WordDescription:
"""Returns the description of a random word.
Expand All @@ -35,7 +34,72 @@ async def get_word_description(
gpt: The GPT model to use.
"""
logger.debug("Getting word description.")
return await controller.get_word_description(session, gpt)
text_task = await controller.get_text_task(session)
logger.debug("Got word description.")
return text_task


@router.get(
"/synonyms",
response_model=schemas.WordSynonyms,
status_code=status.HTTP_200_OK,
summary="Returns synonyms of a random word.",
description="Returns synonyms of a random word.",
)
async def get_word_synonym(
session: orm.Session = fastapi.Depends(sql.get_session),
) -> schemas.WordSynonyms:
"""Returns synonyms of a random word.
Args:
session: The database session.
"""
logger.debug("Getting word synonym.")
text_task = await controller.get_text_task(session)
logger.debug("Got word synonym.")
return text_task


@router.get(
"/antonyms",
response_model=schemas.WordAntonyms,
status_code=status.HTTP_200_OK,
summary="Returns antonyms of a random word.",
description="Returns antonyms of a random word.",
)
async def get_word_antonym(
session: orm.Session = fastapi.Depends(sql.get_session),
) -> schemas.WordAntonyms:
"""Returns antonyms of a random word.
Args:
session: The database session.
"""
logger.debug("Getting word antonym.")
text_task = await controller.get_text_task(session)
logger.debug("Got word antonym.")
return text_task


@router.get(
"/jeopardy",
response_model=schemas.WordJeopardy,
status_code=status.HTTP_200_OK,
summary="Returns a jeopardy question of a random word.",
description="Returns a jeopardy question of a random word.",
)
async def get_word_jeopardy(
session: orm.Session = fastapi.Depends(sql.get_session),
) -> schemas.WordJeopardy:
"""Returns a jeopardy question of a random word.
Args:
session: The database session.
"""
logger.debug("Getting word jeopardy.")
text_task = await controller.get_text_task(session)
logger.debug("Got word jeopardy.")
return text_task


@router.post(
Expand Down Expand Up @@ -67,4 +131,6 @@ async def check_word(
session: The database session.
"""
logger.debug("Checking word.")
return await controller.check_word(word_id, checks, session)
is_correct = await controller.check_word(word_id, checks, session)
logger.debug("Checked word.")
return is_correct
3 changes: 3 additions & 0 deletions tests/endpoint/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class Endpoints(str, enum.Enum):
"""Enum class that represents the available endpoints for the API."""

GET_DESCRIPTION = f"{API_ROOT}/text/description"
GET_SYNONYMS = f"{API_ROOT}/text/synonyms"
GET_ANTONYMS = f"{API_ROOT}/text/antonyms"
GET_JEOPARDY = f"{API_ROOT}/text/jeopardy"
POST_CHECK_WORD = f"{API_ROOT}/text/check/{{word_id}}"
GET_HEALTH = f"{API_ROOT}/health"

Expand Down
Loading

0 comments on commit 68d83c5

Please sign in to comment.