Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyang628 committed May 25, 2024
1 parent b37c2a4 commit 66c7322
Show file tree
Hide file tree
Showing 30 changed files with 569 additions and 474 deletions.
10 changes: 1 addition & 9 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from pydantic import BaseModel

from app.llm.model import LLMType
from app.models.task import Task


class InferenceConfig(BaseModel):
"""The main class describing the inference configuration."""

llm_type: dict[Task, LLMType] = {
# Task.SUMMARISE: LLMType.GEMINI_PRO,
Task.SUMMARISE: LLMType.OPENAI_GPT3_5,
# Task.PRACTICE: LLMType.OPENAI_GPT4_TURBO
# Task.PRACTICE: LLMType.COHERE_COMMAND_R_PLUS
Task.PRACTICE: LLMType.OPENAI_GPT3_5
# Task.PRACTICE: LLMType.GEMINI_PRO
}
llm_type: LLMType = LLMType.OPENAI_GPT3_5
71 changes: 2 additions & 69 deletions app/control/post/examiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def post_process(language: str, question: str, half_completed_code: str, fully_c
raise TypeError(f"Fully-completed code is not a string: {fully_completed_code}")
if not isinstance(language, str):
raise TypeError(f"Language is not a string: {language}")
half_completed_code = _verify_todo_marker_presence(half_completed_code=half_completed_code)
half_completed_code, fully_completed_code = _verify_expected_similarity_and_difference(half_completed_code=half_completed_code, fully_completed_code=fully_completed_code)

return (language, question, half_completed_code, fully_completed_code)
except (ValueError, TypeError) as e:
log.error(f"Error post-processing practice: {e}")
Expand All @@ -38,62 +37,7 @@ def post_process(language: str, question: str, half_completed_code: str, fully_c
log.error(f"Unexpected error while post-processing practice: {e}")
raise e

def _verify_expected_similarity_and_difference(half_completed_code: str, fully_completed_code: str) -> tuple[str, str]:
"""Verifies that the question and answer blocks are similar before the {TODO_MARKER} and different after the {TODO_MARKER}.
This ensures that our output is streamlined for easy verification by the user.

Args:
question (str): The question block generated by the LLM.
answer (str): The answer block generated by the LLM.
Returns:
tuple[str, str]: The verified question and answer strings respectively.
"""
question_lines = half_completed_code.strip().split("\n")
answer_lines = fully_completed_code.strip().split("\n")
todo_marker_found = False

q_index = 0
a_index = 0

only_comments: bool = True
# Loop through each line until we run out of lines in question
while q_index < len(question_lines):
if "TODO" in question_lines[q_index]:
todo_marker_found = True
q_index += 1
continue # Skip TODO marker line and proceed to enforce matching on subsequent lines

if todo_marker_found:
# Ensure there are enough lines left in the answer to match the remaining question lines
if a_index >= len(answer_lines):
raise ValueError("The answer does not cover all lines in the question after the TODO marker.")

# Check for matching lines strictly after TODO
while a_index < len(answer_lines) and question_lines[q_index] != answer_lines[a_index]:
curr_answer_line: str = answer_lines[a_index].strip()
if not (curr_answer_line.startswith("#") or curr_answer_line.startswith("//") or curr_answer_line == ""):
only_comments = False
a_index += 1 # Skip non-matching lines in the answer until a match is found

if a_index < len(answer_lines) and question_lines[q_index] == answer_lines[a_index]:
if only_comments:
raise ValueError("The user input section contains only comments.")
q_index += 1
a_index += 1 # Increment both to continue matching
else:
raise ValueError("The question and answer blocks differ after the TODO marker.")
else:
# Match lines one-to-one before the TODO marker
if a_index >= len(answer_lines):
raise ValueError("The answer does not cover all lines in the question before the TODO marker.")
if question_lines[q_index] != answer_lines[a_index]:
raise ValueError("The question and answer blocks differ before the TODO marker.")
q_index += 1
a_index += 1

return half_completed_code, fully_completed_code


def _remove_output_wrapper(text: str) -> str:
Expand All @@ -113,18 +57,7 @@ def _remove_output_wrapper(text: str) -> str:
return text[:index].strip()


def _verify_todo_marker_presence(half_completed_code: str) -> str:
"""Verifies that the text contains the {TODO_MARKER}.
Args:
text (str): The text to be processed.
Returns:
str: The text with the {TODO_MARKER} if it is present.
"""
if TODO_MARKER not in half_completed_code:
raise ValueError(f"The text does not contain the placeholder {TODO_MARKER}.")
return half_completed_code


def _determine_question_and_answer(block_1: str, block_2: str) -> tuple[str, str]:
"""Determines which is the question and answer block by checking which block contains the {TODO_MARKER}. Returns the question and answer in order.
Expand Down
164 changes: 164 additions & 0 deletions app/control/post/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import logging
from typing import Any, Optional

from app.exceptions.exception import LogicError
from app.process.types import TODO_MARKER
from app.prompts.generator.functions import NotesFunctions

log = logging.getLogger(__name__)

def post_process(
topic: str,
goal: str,
overview: str,
key_concepts_lst: list[dict[str, str]],
tips_lst: Optional[list[dict[str, str]]],
mcq_practice: Optional[dict[str, Any]],
code_practice: Optional[dict[str, str]]
) -> dict[str, Any]:
"""_summary_
Args:
topic (str): The topic of the revision notes
goal (str): The goal of the revison notes
overview (str): The overview of the revision notes
key_concepts_lst (list[dict[str, str]]): The list of key concepts of the revision notes
tips_lst (list[dict[str, str]]): The list of tips of the revision notes
mcq_practice (Optional[dict[str, Any]]): The multiple-choice question practice of the revision notes
code_practice (Optional[dict[str, str]]): The code practice of the revision notes
Returns:
dict[str, Any]: A dictionary containing the parts of the summary
"""
try:
if not isinstance(topic, str):
raise TypeError(f"Topic is not a string: {topic}")
else:
_reject_unlikely_topics(topic=topic)

if not isinstance(goal, str):
raise TypeError(f"Goal is not a string: {goal}")
if not isinstance(overview, str):
raise TypeError(f"Overview is not a string: {overview}")
if not isinstance(key_concepts_lst, list):
raise TypeError(f"Key concepts list is not a list: {key_concepts_lst}")
if tips_lst:
if not isinstance(tips_lst, list):
raise TypeError(f"Tips list is not a list: {tips_lst}")
if mcq_practice:
if not isinstance(mcq_practice, dict):
raise TypeError(f"MCQ practice is not a dictionary: {mcq_practice}")
if not isinstance(mcq_practice[NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS.value], list):
raise TypeError(f"MCQ practice wrong options is not a list: {mcq_practice[NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS.value]}")
if code_practice:
if not isinstance(code_practice, dict):
raise TypeError(f"Code practice is not a dictionary: {code_practice}")

half_completed_code: str = code_practice[NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE.value]
fully_completed_code: str = code_practice[NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE.value]
_verify_todo_marker_presence(
half_completed_code=half_completed_code
)
_verify_expected_similarity_and_difference(
half_completed_code=half_completed_code,
fully_completed_code=fully_completed_code
)

return {
NotesFunctions.TOPIC.value: topic,
NotesFunctions.GOAL.value: goal,
NotesFunctions.OVERVIEW.value: overview,
NotesFunctions.KEY_CONCEPTS.value: key_concepts_lst,
NotesFunctions.TIPS.value: tips_lst,
NotesFunctions.MCQ_PRACTICE.value: mcq_practice,
NotesFunctions.CODE_PRACTICE.value: code_practice
}
except (TypeError, ValueError) as e:
log.error(f"Logic error while post-processing summary: {e}")
raise LogicError(message=str(e))
except Exception as e:
log.error(f"Unexpected error while post-processing summary: {e}")
raise e


def _reject_unlikely_topics(topic: str):
"""Throws an error if the topic is unlikely to be valid/of good quality.
The observation is that most valid topics have more than one word. One-word topics generated by LLM tend to be things like "Issue", "Problem", "Solution", etc. that are not what we want.
Args:
topic (str): the topic-content dictionary to be checked.
"""

if len(topic.split(" ")) <= 1:
raise ValueError(f"Topic '{topic}' is unlikely to be a valid topic.")

def _enforce_code_language_presence(key_concepts_lst: list[dict[str, str]]):
"""Enforces that the code language is present if the code example is present.
Args:
key_concepts_lst (list[dict[str, str]]): the list of key concepts to be checked.
"""
for key_concept in key_concepts_lst:
code_example: Optional[dict[str, str]] = key_concept.get(NotesFunctions.KEY_CONCEPT_CODE_EXAMPLE.value)
if not code_example:
continue
if code_example.get(NotesFunctions.KEY_CONCEPT_CODE.value) and not code_example.get(NotesFunctions.KEY_CONCEPT_LANGUAGE.value):
raise ValueError(f"Code example present but code language not specified for key concept: {key_concept}")

def _verify_expected_similarity_and_difference(half_completed_code: str, fully_completed_code: str):
"""Verifies that the question and answer blocks are similar before the {TODO_MARKER} and different after the {TODO_MARKER}.
This ensures that our output is streamlined for easy verification by the user.
Args:
question (str): The question block generated by the LLM.
answer (str): The answer block generated by the LLM.
"""
question_lines = half_completed_code.strip().split("\n")
answer_lines = fully_completed_code.strip().split("\n")
todo_marker_found = False

q_index = 0
a_index = 0

only_comments: bool = True
# Loop through each line until we run out of lines in question
while q_index < len(question_lines):
if "TODO" in question_lines[q_index]:
todo_marker_found = True
q_index += 1
continue # Skip TODO marker line and proceed to enforce matching on subsequent lines

if todo_marker_found:
# Ensure there are enough lines left in the answer to match the remaining question lines
if a_index >= len(answer_lines):
raise ValueError("The answer does not cover all lines in the question after the TODO marker.")

# Check for matching lines strictly after TODO
while a_index < len(answer_lines) and question_lines[q_index] != answer_lines[a_index]:
curr_answer_line: str = answer_lines[a_index].strip()
if not (curr_answer_line.startswith("#") or curr_answer_line.startswith("//") or curr_answer_line == ""):
only_comments = False
a_index += 1 # Skip non-matching lines in the answer until a match is found

if a_index < len(answer_lines) and question_lines[q_index] == answer_lines[a_index]:
if only_comments:
raise ValueError("The user input section contains only comments.")
q_index += 1
a_index += 1 # Increment both to continue matching
else:
raise ValueError("The question and answer blocks differ after the TODO marker.")
else:
# Match lines one-to-one before the TODO marker
if a_index >= len(answer_lines):
raise ValueError("The answer does not cover all lines in the question before the TODO marker.")
if question_lines[q_index] != answer_lines[a_index]:
raise ValueError("The question and answer blocks differ before the TODO marker.")
q_index += 1
a_index += 1

def _verify_todo_marker_presence(half_completed_code: str):
"""Verifies that the text contains the {TODO_MARKER}."""
if TODO_MARKER not in half_completed_code:
raise ValueError(f"The text does not contain the placeholder {TODO_MARKER}.")
77 changes: 0 additions & 77 deletions app/control/post/summariser.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
log = logging.getLogger(__name__)

def pre_process(
conversation_dict: dict[str, Any], max_input_tokens: int
conversation: dict[str, Any], max_input_tokens: int
) -> tuple[list[Conversation], int]:
"""Pre-processes the conversation in preparation for summarisation.
Expand All @@ -23,7 +23,7 @@ def pre_process(
"""
try:
conversation_lst, token_sum = _split_by_token_length(
conversation_dict=conversation_dict, max_input_tokens=max_input_tokens
conversation_dict=conversation, max_input_tokens=max_input_tokens
)
return conversation_lst, token_sum
except LogicError as e:
Expand Down
3 changes: 2 additions & 1 deletion app/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel

from app.models.content import Content
from app.prompts.config import PromptMessageConfig


Expand Down Expand Up @@ -32,7 +33,7 @@ async def send_message(
self,
system_message: str,
user_message: str,
config: PromptMessageConfig
content_lst: list[Content]
) -> str:
"""Sends a message to the AI and returns the response."""
pass
Expand Down
Loading

0 comments on commit 66c7322

Please sign in to comment.