Merge pull request #1 from Still-Human/jinyang
Jinyang
jinyang628 authored May 5, 2024
2 parents a5b014f + bfd5d11 commit 164b7e8
Showing 18 changed files with 267 additions and 334 deletions.
8 changes: 4 additions & 4 deletions app/config.py
@@ -8,10 +8,10 @@ class InferenceConfig(BaseModel):
"""The main class describing the inference configuration."""

llm_type: dict[Task, LLMType] = {
Task.SUMMARISE: LLMType.GEMINI_PRO,
# Task.SUMMARISE: LLMType.OPENAI_GPT3_5,
# Task.SUMMARISE: LLMType.GEMINI_PRO,
Task.SUMMARISE: LLMType.OPENAI_GPT3_5,
# Task.PRACTICE: LLMType.OPENAI_GPT4_TURBO
# Task.PRACTICE: LLMType.COHERE_COMMAND_R_PLUS
# Task.PRACTICE: LLMType.OPENAI_GPT3_5
Task.PRACTICE: LLMType.GEMINI_PRO
Task.PRACTICE: LLMType.OPENAI_GPT3_5
# Task.PRACTICE: LLMType.GEMINI_PRO
}
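As a usage sketch of this mapping, the snippet below shows how a caller might resolve the configured model per task after the switch to GPT-3.5; only InferenceConfig, Task, and LLMType appear in this diff, and the import path for Task is an assumption.

# Hypothetical lookup of the configured LLM per task (illustrative only).
from app.config import InferenceConfig
from app.llm.model import LLMType
from app.process.types import Task  # assumed location of the Task enum

config = InferenceConfig()
assert config.llm_type[Task.SUMMARISE] == LLMType.OPENAI_GPT3_5  # after this change
assert config.llm_type[Task.PRACTICE] == LLMType.OPENAI_GPT3_5   # after this change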
45 changes: 24 additions & 21 deletions app/control/post/examiner.py
@@ -2,39 +2,42 @@
import re

from app.exceptions.exception import LogicError
from app.llm.model import LLMType
from app.process.types import TODO_MARKER

log = logging.getLogger(__name__)


def post_process(practice: str, llm_type: LLMType) -> tuple[str, str, str]:
"""Post-processes the output of the examiner.
Args:
practice (str): The output of the examiner.
llm_type (LLMType): The type of LLM model used to generate the practice. This is important as certain tokens are added to the output by specific LLM models and need to be removed.
def post_process(language: str, question: str, half_completed_code: str, fully_completed_code: str) -> tuple[str, str, str, str]:
"""Post-processes the practice question and answer blocks to ensure consistency and correctness.
Args:
language (str): The language of the coding question
question (str): The question that is formulated based on the summary.
half_completed_code (str): The half-completed code with the TODO marker in place of the missing code.
fully_completed_code (str): The fully-completed code, with the missing parts annotated by the TODO marker filled.
Returns:
tuple[str, str, str]: The language, question, and answer of the practice generated by the LLM.
tuple[str, str, str, str]: The language, question, half-completed code, and fully-completed code of the practice generated by the LLM.
"""
try:
if not isinstance(practice, str):
raise TypeError(f"Input is not a string: {practice}")
if llm_type == LLMType.CLAUDE_INSTANT_1 or llm_type == LLMType.CLAUDE_3_SONNET:
practice = _remove_output_wrapper(text=practice)
language, block_1, block_2 = _extract_code(text=practice)
question, answer = _determine_question_and_answer(block_1=block_1, block_2=block_2)
question, answer = _verify_expected_similarity_and_difference(question=question, answer=answer)
return (language, question, answer)
if not isinstance(question, str):
raise TypeError(f"Question is not a string: {question}")
if not isinstance(half_completed_code, str):
raise TypeError(f"Half-completed code is not a string: {half_completed_code}")
if not isinstance(fully_completed_code, str):
raise TypeError(f"Fully-completed code is not a string: {fully_completed_code}")
if not isinstance(language, str):
raise TypeError(f"Language is not a string: {language}")
half_completed_code, fully_completed_code = _verify_expected_similarity_and_difference(half_completed_code=half_completed_code, fully_completed_code=fully_completed_code)
return (language, question, half_completed_code, fully_completed_code)
except (ValueError, TypeError) as e:
log.error(f"Error post-processing practice: {e}")
raise LogicError(message=str(e))
except Exception as e:
log.error(f"Unexpected error while post-processing practice: {e}")
raise e

def _verify_expected_similarity_and_difference(question: str, answer: str) -> tuple[str, str]:
def _verify_expected_similarity_and_difference(half_completed_code: str, fully_completed_code: str) -> tuple[str, str]:
"""Verifies that the question and answer blocks are similar before the {TODO_MARKER} and different after the {TODO_MARKER}.
This ensures that our output is streamlined for easy verification by the user.
@@ -46,8 +49,8 @@ def _verify_expected_similarity_and_difference(question: str, answer: str) -> tu
Returns:
tuple[str, str]: The verified question and answer strings respectively.
"""
question_lines = question.strip().split("\n")
answer_lines = answer.strip().split("\n")
question_lines = half_completed_code.strip().split("\n")
answer_lines = fully_completed_code.strip().split("\n")
todo_marker_found = False

q_index = 0
@@ -89,7 +92,7 @@ def _verify_expected_similarity_and_difference(question: str, answer: str) -> tu
q_index += 1
a_index += 1

return question, answer
return half_completed_code, fully_completed_code


def _remove_output_wrapper(text: str) -> str:
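A minimal sketch of calling the reworked post_process with the four fields now returned by the LLM, using made-up sample strings; whether a given pair of blocks passes the similarity check depends on the parts of _verify_expected_similarity_and_difference collapsed in this diff.

from app.control.post.examiner import post_process
from app.process.types import TODO_MARKER

# Illustrative inputs: both blocks match before the TODO marker and differ after it.
half = f"def square(x):\n    {TODO_MARKER}\n"
full = "def square(x):\n    return x * x\n"

language, question, half_completed_code, fully_completed_code = post_process(
    language="python",
    question="Complete the function so that it returns the square of x.",
    half_completed_code=half,
    fully_completed_code=full,
)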
96 changes: 14 additions & 82 deletions app/control/post/summariser.py
@@ -1,29 +1,25 @@
import logging
import re

from app.exceptions.exception import LogicError
from app.llm.model import LLMType

log = logging.getLogger(__name__)


def post_process(summary: str, llm_type: LLMType) -> dict[str, str]:
def post_process(topic: str, content: str) -> dict[str, str]:
"""Processes the output of the summariser and returns a dictionary containing the topic-summary pairs.
Args:
summary (str): The output of the summariser.
topic (str): The topic of the summary.
content (str): The content of the summary.
llm_type (LLMType): The type of LLM model used to generate the summary. This is important as certain tokens are added to the output by specific LLM models and need to be removed.
"""
try:
if not isinstance(summary, str):
raise TypeError(f"Input is not a string: {summary}")
if llm_type == LLMType.CLAUDE_INSTANT_1 or llm_type == LLMType.CLAUDE_3_SONNET:
summary = _remove_output_wrapper(text=summary)
removed_header: str = _remove_header(text=summary)
topic_content_dict: dict[str, str] = _extract_info(text=removed_header)
_reject_unlikely_topics(topic_content_dict)

return topic_content_dict
if not isinstance(topic, str):
raise TypeError(f"Topic is not a string: {topic}")
if not isinstance(content, str):
raise TypeError(f"Content is not a string: {content}")
_reject_unlikely_topics(topic=topic)
return {topic: content}
except (TypeError, ValueError) as e:
log.error(f"Logic error while post-processing summary: {e}")
raise LogicError(message=str(e))
@@ -32,78 +28,14 @@ def post_process(summary: str, llm_type: LLMType) -> dict[str, str]:
raise e


def _remove_output_wrapper(text: str) -> str:
"""Removes the output wrapper from the text and returns the remaining text.
Args:
text (str): The text to be processed.
"""
index = text.find("</output>")
if index == -1:
raise ValueError(
f"The text does not contain the expected index '<output>': {text}"
)
return text[:index].strip()


def _reject_unlikely_topics(topic_content_dict: dict[str, str]):
"""Throws an error if any of the topics in the dictionary are unlikely to be valid topics.
def _reject_unlikely_topics(topic: str):
"""Throws an error if the topic is unlikely to be valid/of good quality.
The observation is that most valid topics have more than one word. One-word topics generated by LLM tend to be things like "Issue", "Problem", "Solution", etc. that are not what we want.
Args:
topic_content_dict (dict[str, str]): the topic-content dictionary to be checked.
topic (str): the topic to be checked.
"""
if not topic_content_dict:
raise TypeError("No topics found in the dictionary.")

if not isinstance(topic_content_dict, dict):
raise TypeError(f"Input is not a dictionary: {topic_content_dict}")

for key, value in topic_content_dict.items():
if len(key.split(" ")) <= 1:
raise ValueError(f"Topic '{key}' is unlikely to be a valid topic.")


def _remove_header(text: str) -> str:
"""Removes the header from the text and returns the remaining text.
Args:
text (str): The text to be processed.

Returns:
str: The text without the header.
"""
index = text.find("1.")
if index == -1:
raise ValueError(f"The text does not contain the expected index '1.': {text}")
return text[index:].strip()


def _extract_info(text: str) -> dict[str, str]:
"""Extracts the text into topic-summary pairs. Returns a dict containing these pairs.
Example input:
1. **Time Complexity of Insertion**: Inserting an element at the beginning of a Python list is a linear-time operation (O(n)), while appending at the end is constant-time (O(1)).
2. **Prepending Using `insert`**: The `list.insert(0, item)` method can be used to prepend an element, but it's less efficient for large lists.
Example output:
{
'Time Complexity of Insertion': 'Inserting an element at the beginning of a Python list is a linear-time operation (O(n)), while appending at the end is constant-time (O(1)).',
'Prepending Using `insert`': "The `list.insert(0, item)` method can be used to prepend an element, but it's less efficient for large lists."
}
"""
# Regular expression to match **key** and the associated text
pattern = r"\*\*(.+?)\*\*:\s*(.+?)(?=\n\d\.|\Z)"

# Find all matches and build the dictionary
output: dict[str, str] = {
match.group(1): match.group(2).strip()
for match in re.finditer(pattern, text, re.DOTALL)
}

if len(output) == 0:
raise ValueError(f"No matches found in the text: {text}")

return output
if len(topic.split(" ")) <= 1:
raise ValueError(f"Topic '{topic}' is unlikely to be a valid topic.")
3 changes: 3 additions & 0 deletions app/llm/base.py
@@ -3,6 +3,8 @@

from pydantic import BaseModel

from app.prompts.config import PromptMessageConfig


class LLMConfig(BaseModel):
temperature: float
@@ -30,6 +32,7 @@ async def send_message(
self,
system_message: str,
user_message: str,
config: PromptMessageConfig
) -> str:
"""Sends a message to the AI and returns the response."""
pass
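To illustrate what the extra parameter means for subclasses, here is a hypothetical stub implementing the updated send_message signature; EchoModel is not part of the codebase.

from app.llm.base import LLMBaseModel
from app.prompts.config import PromptMessageConfig

class EchoModel(LLMBaseModel):
    """Hypothetical subclass showing the new three-argument send_message."""

    async def send_message(
        self,
        system_message: str,
        user_message: str,
        config: PromptMessageConfig,
    ) -> str:
        # A real model would branch on config (e.g. SUMMARY vs PRACTICE) here.
        return f"[{config}] {user_message}"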
10 changes: 5 additions & 5 deletions app/llm/model.py
@@ -7,7 +7,7 @@
from app.llm.cohere import Cohere
from app.llm.google_ai import GoogleAI
from app.llm.llama3 import Llama3
from app.llm.open_ai import OpenAI
from app.llm.open_ai import OpenAi


class LLMType(StrEnum):
@@ -25,12 +25,12 @@ def default_config(self) -> LLMConfig:
if self == LLMType.OPENAI_GPT4_TURBO:
return LLMConfig(
temperature=1,
max_tokens=4096,
max_tokens=3000,
)
elif self == LLMType.OPENAI_GPT3_5:
return LLMConfig(
temperature=1,
max_tokens=4096,
max_tokens=3000,
)
elif self == LLMType.GEMINI_PRO:
return LLMConfig(
@@ -79,11 +79,11 @@ def __init__(
model_config: LLMConfig = model_config or model_type.default_config()
match model_type:
case LLMType.OPENAI_GPT4_TURBO:
self._model = OpenAI(
self._model = OpenAi(
model_name=model_type.value, model_config=model_config
)
case LLMType.OPENAI_GPT3_5:
self._model = OpenAI(
self._model = OpenAi(
model_name=model_type.value, model_config=model_config
)
case LLMType.GEMINI_PRO:
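A quick check of the lowered token limit, assuming LLMConfig exposes the two fields shown in this diff:

from app.llm.model import LLMType

config = LLMType.OPENAI_GPT3_5.default_config()
assert config.temperature == 1
assert config.max_tokens == 3000  # reduced from 4096 in this commit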
76 changes: 60 additions & 16 deletions app/llm/open_ai.py
@@ -1,32 +1,76 @@
import json
import logging
from typing import Any
from openai import OpenAI
import os

from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from app.exceptions.exception import InferenceFailure
from app.llm.base import LLMBaseModel, LLMConfig
from app.prompts.config import PromptMessageConfig
from app.prompts.examiner.functions import PracticeFunctions, get_practice_functions
from app.prompts.summariser.functions import get_summary_functions, SummaryFunctions

log = logging.getLogger(__name__)

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

class OpenAI(LLMBaseModel):
class OpenAi(LLMBaseModel):
"""This class handles the interaction with OpenAI API."""


def __init__(self, model_name: str, model_config: LLMConfig):
super().__init__(model_name=model_name, model_config=model_config)
self._model = ChatOpenAI(
model_name=model_name,
temperature=model_config.temperature,
max_tokens=model_config.max_tokens,
self._client = OpenAI(
api_key=OPENAI_API_KEY,
)


async def send_message(self, system_message: str, user_message: str) -> str:
async def send_message(self, system_message: str, user_message: str, config: PromptMessageConfig) -> Any:
"""Sends a message to OpenAI and returns the response."""
messages = [
SystemMessage(content=system_message),
HumanMessage(content=user_message),
]

log.info(f"Sending messages to OpenAI")
response = (await self._model.ainvoke(messages)).content
return response
match config:
case PromptMessageConfig.SUMMARY:
response = self._client.chat.completions.create(
model = self._model_name,
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
functions=get_summary_functions(),
function_call = {"name": SummaryFunctions.GET_SUMMARY}
)
try:
json_response: dict[str, str] = json.loads(response.choices[0].message.function_call.arguments)
topic: str = json_response[SummaryFunctions.TOPIC]
content: str = json_response[SummaryFunctions.CONTENT]
log.info(f"Topic: {topic}, Content: {content}")
return (topic, content)
except Exception as e:
log.error(f"Error processing or receiving OpenAI response: {str(e)}")
raise InferenceFailure("Error processing OpenAI response")
case PromptMessageConfig.PRACTICE:
response = self._client.chat.completions.create(
model = self._model_name,
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
functions=get_practice_functions(),
function_call = {"name": PracticeFunctions.GET_PRACTICE}
)
try:
json_response: dict[str, str] = json.loads(response.choices[0].message.function_call.arguments)
log.info(f"Practice: {json_response}")
language: str = json_response[PracticeFunctions.LANGUAGE]
question: str = json_response[PracticeFunctions.QUESTION]
half_completed_code: str = json_response[PracticeFunctions.HALF_COMPLETED_CODE]
fully_completed_code: str = json_response[PracticeFunctions.FULLY_COMPLETED_CODE]
log.info(f"Language: {language}, Question: {question}, Half-completed-code: {half_completed_code}, Fully-completed-code: {fully_completed_code}")
return (language, question, half_completed_code, fully_completed_code)
except Exception as e:
log.error(f"Error processing or receiving OpenAI response: {str(e)}")
raise InferenceFailure("Error processing OpenAI response")
case _:
raise InferenceFailure("Invalid config type")

