diff --git a/app/config.py b/app/config.py index 4ea1eb3..84f477c 100644 --- a/app/config.py +++ b/app/config.py @@ -1,17 +1,9 @@ from pydantic import BaseModel from app.llm.model import LLMType -from app.models.task import Task class InferenceConfig(BaseModel): """The main class describing the inference configuration.""" - llm_type: dict[Task, LLMType] = { - # Task.SUMMARISE: LLMType.GEMINI_PRO, - Task.SUMMARISE: LLMType.OPENAI_GPT3_5, - # Task.PRACTICE: LLMType.OPENAI_GPT4_TURBO - # Task.PRACTICE: LLMType.COHERE_COMMAND_R_PLUS - Task.PRACTICE: LLMType.OPENAI_GPT3_5 - # Task.PRACTICE: LLMType.GEMINI_PRO - } + llm_type: LLMType = LLMType.OPENAI_GPT3_5 \ No newline at end of file diff --git a/app/control/post/examiner.py b/app/control/post/examiner.py index dcc39b9..f6aa6c6 100644 --- a/app/control/post/examiner.py +++ b/app/control/post/examiner.py @@ -28,8 +28,7 @@ def post_process(language: str, question: str, half_completed_code: str, fully_c raise TypeError(f"Fully-completed code is not a string: {fully_completed_code}") if not isinstance(language, str): raise TypeError(f"Language is not a string: {language}") - half_completed_code = _verify_todo_marker_presence(half_completed_code=half_completed_code) - half_completed_code, fully_completed_code = _verify_expected_similarity_and_difference(half_completed_code=half_completed_code, fully_completed_code=fully_completed_code) + return (language, question, half_completed_code, fully_completed_code) except (ValueError, TypeError) as e: log.error(f"Error post-processing practice: {e}") @@ -38,62 +37,7 @@ def post_process(language: str, question: str, half_completed_code: str, fully_c log.error(f"Unexpected error while post-processing practice: {e}") raise e -def _verify_expected_similarity_and_difference(half_completed_code: str, fully_completed_code: str) -> tuple[str, str]: - """Verifies that the question and answer blocks are similar before the {TODO_MARKER} and different after the {TODO_MARKER}. - - This ensures that our output is streamlined for easy verification by the user. - Args: - question (str): The question block generated by the LLM. - answer (str): The answer block generated by the LLM. - - Returns: - tuple[str, str]: The verified question and answer strings respectively. 
- """ - question_lines = half_completed_code.strip().split("\n") - answer_lines = fully_completed_code.strip().split("\n") - todo_marker_found = False - - q_index = 0 - a_index = 0 - - only_comments: bool = True - # Loop through each line until we run out of lines in question - while q_index < len(question_lines): - if "TODO" in question_lines[q_index]: - todo_marker_found = True - q_index += 1 - continue # Skip TODO marker line and proceed to enforce matching on subsequent lines - - if todo_marker_found: - # Ensure there are enough lines left in the answer to match the remaining question lines - if a_index >= len(answer_lines): - raise ValueError("The answer does not cover all lines in the question after the TODO marker.") - - # Check for matching lines strictly after TODO - while a_index < len(answer_lines) and question_lines[q_index] != answer_lines[a_index]: - curr_answer_line: str = answer_lines[a_index].strip() - if not (curr_answer_line.startswith("#") or curr_answer_line.startswith("//") or curr_answer_line == ""): - only_comments = False - a_index += 1 # Skip non-matching lines in the answer until a match is found - - if a_index < len(answer_lines) and question_lines[q_index] == answer_lines[a_index]: - if only_comments: - raise ValueError("The user input section contains only comments.") - q_index += 1 - a_index += 1 # Increment both to continue matching - else: - raise ValueError("The question and answer blocks differ after the TODO marker.") - else: - # Match lines one-to-one before the TODO marker - if a_index >= len(answer_lines): - raise ValueError("The answer does not cover all lines in the question before the TODO marker.") - if question_lines[q_index] != answer_lines[a_index]: - raise ValueError("The question and answer blocks differ before the TODO marker.") - q_index += 1 - a_index += 1 - - return half_completed_code, fully_completed_code def _remove_output_wrapper(text: str) -> str: @@ -113,18 +57,7 @@ def _remove_output_wrapper(text: str) -> str: return text[:index].strip() -def _verify_todo_marker_presence(half_completed_code: str) -> str: - """Verifies that the text contains the {TODO_MARKER}. - - Args: - text (str): The text to be processed. - - Returns: - str: The text with the {TODO_MARKER} if it is present. - """ - if TODO_MARKER not in half_completed_code: - raise ValueError(f"The text does not contain the placeholder {TODO_MARKER}.") - return half_completed_code + def _determine_question_and_answer(block_1: str, block_2: str) -> tuple[str, str]: """Determines which is the question and answer block by checking which block contains the {TODO_MARKER}. Returns the question and answer in order. 
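The TODO-marker checks dropped from this post-processor are relocated rather than removed; they reappear in app/control/post/generator.py below. A minimal usage sketch of the relocated checks, assuming the module paths introduced in this patch and that TODO_MARKER (defined in app/process/types.py, not shown here) contains the literal substring "TODO":

    from app.control.post.generator import (
        _verify_expected_similarity_and_difference,
        _verify_todo_marker_presence,
    )
    from app.process.types import TODO_MARKER

    # Hypothetical pair: identical code except for the TODO section.
    half_completed = f"def add(a, b):\n    {TODO_MARKER}\n    return result"
    fully_completed = "def add(a, b):\n    result = a + b\n    return result"

    # Both checks return None on success and raise ValueError on failure,
    # e.g. when the marker is missing, when the blocks diverge outside the
    # TODO section, or when the filled-in section contains only comments.
    _verify_todo_marker_presence(half_completed_code=half_completed)
    _verify_expected_similarity_and_difference(
        half_completed_code=half_completed,
        fully_completed_code=fully_completed,
    )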
diff --git a/app/control/post/generator.py b/app/control/post/generator.py new file mode 100644 index 0000000..e9ed6b4 --- /dev/null +++ b/app/control/post/generator.py @@ -0,0 +1,164 @@ +import logging +from typing import Any, Optional + +from app.exceptions.exception import LogicError +from app.process.types import TODO_MARKER +from app.prompts.generator.functions import NotesFunctions + +log = logging.getLogger(__name__) + +def post_process( + topic: str, + goal: str, + overview: str, + key_concepts_lst: list[dict[str, str]], + tips_lst: Optional[list[dict[str, str]]], + mcq_practice: Optional[dict[str, Any]], + code_practice: Optional[dict[str, str]] + ) -> dict[str, Any]: + """Post-processes and validates the revision notes generated by the LLM. + + Args: + topic (str): The topic of the revision notes + goal (str): The goal of the revision notes + overview (str): The overview of the revision notes + key_concepts_lst (list[dict[str, str]]): The list of key concepts of the revision notes + tips_lst (Optional[list[dict[str, str]]]): The list of tips of the revision notes + mcq_practice (Optional[dict[str, Any]]): The multiple-choice question practice of the revision notes + code_practice (Optional[dict[str, str]]): The code practice of the revision notes + + Returns: + dict[str, Any]: A dictionary containing the parts of the revision notes + """ + try: + if not isinstance(topic, str): + raise TypeError(f"Topic is not a string: {topic}") + else: + _reject_unlikely_topics(topic=topic) + + if not isinstance(goal, str): + raise TypeError(f"Goal is not a string: {goal}") + if not isinstance(overview, str): + raise TypeError(f"Overview is not a string: {overview}") + if not isinstance(key_concepts_lst, list): + raise TypeError(f"Key concepts list is not a list: {key_concepts_lst}") + if tips_lst: + if not isinstance(tips_lst, list): + raise TypeError(f"Tips list is not a list: {tips_lst}") + if mcq_practice: + if not isinstance(mcq_practice, dict): + raise TypeError(f"MCQ practice is not a dictionary: {mcq_practice}") + if not isinstance(mcq_practice[NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS.value], list): + raise TypeError(f"MCQ practice wrong options is not a list: {mcq_practice[NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS.value]}") + if code_practice: + if not isinstance(code_practice, dict): + raise TypeError(f"Code practice is not a dictionary: {code_practice}") + + half_completed_code: str = code_practice[NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE.value] + fully_completed_code: str = code_practice[NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE.value] + _verify_todo_marker_presence( + half_completed_code=half_completed_code + ) + _verify_expected_similarity_and_difference( + half_completed_code=half_completed_code, + fully_completed_code=fully_completed_code + ) + + return { + NotesFunctions.TOPIC.value: topic, + NotesFunctions.GOAL.value: goal, + NotesFunctions.OVERVIEW.value: overview, + NotesFunctions.KEY_CONCEPTS.value: key_concepts_lst, + NotesFunctions.TIPS.value: tips_lst, + NotesFunctions.MCQ_PRACTICE.value: mcq_practice, + NotesFunctions.CODE_PRACTICE.value: code_practice + } + except (TypeError, ValueError) as e: + log.error(f"Logic error while post-processing summary: {e}") + raise LogicError(message=str(e)) + except Exception as e: + log.error(f"Unexpected error while post-processing summary: {e}") + raise e + + +def _reject_unlikely_topics(topic: str): + """Throws an error if the topic is unlikely to be valid/of good quality. + + The observation is that most valid topics have more than one word.
One-word topics generated by LLM tend to be things like "Issue", "Problem", "Solution", etc. that are not what we want. + + Args: + topic (str): the topic-content dictionary to be checked. + """ + + if len(topic.split(" ")) <= 1: + raise ValueError(f"Topic '{topic}' is unlikely to be a valid topic.") + +def _enforce_code_language_presence(key_concepts_lst: list[dict[str, str]]): + """Enforces that the code language is present if the code example is present. + + Args: + key_concepts_lst (list[dict[str, str]]): the list of key concepts to be checked. + """ + for key_concept in key_concepts_lst: + code_example: Optional[dict[str, str]] = key_concept.get(NotesFunctions.KEY_CONCEPT_CODE_EXAMPLE.value) + if not code_example: + continue + if code_example.get(NotesFunctions.KEY_CONCEPT_CODE.value) and not code_example.get(NotesFunctions.KEY_CONCEPT_LANGUAGE.value): + raise ValueError(f"Code example present but code language not specified for key concept: {key_concept}") + +def _verify_expected_similarity_and_difference(half_completed_code: str, fully_completed_code: str): + """Verifies that the question and answer blocks are similar before the {TODO_MARKER} and different after the {TODO_MARKER}. + + This ensures that our output is streamlined for easy verification by the user. + + Args: + question (str): The question block generated by the LLM. + answer (str): The answer block generated by the LLM. + """ + question_lines = half_completed_code.strip().split("\n") + answer_lines = fully_completed_code.strip().split("\n") + todo_marker_found = False + + q_index = 0 + a_index = 0 + + only_comments: bool = True + # Loop through each line until we run out of lines in question + while q_index < len(question_lines): + if "TODO" in question_lines[q_index]: + todo_marker_found = True + q_index += 1 + continue # Skip TODO marker line and proceed to enforce matching on subsequent lines + + if todo_marker_found: + # Ensure there are enough lines left in the answer to match the remaining question lines + if a_index >= len(answer_lines): + raise ValueError("The answer does not cover all lines in the question after the TODO marker.") + + # Check for matching lines strictly after TODO + while a_index < len(answer_lines) and question_lines[q_index] != answer_lines[a_index]: + curr_answer_line: str = answer_lines[a_index].strip() + if not (curr_answer_line.startswith("#") or curr_answer_line.startswith("//") or curr_answer_line == ""): + only_comments = False + a_index += 1 # Skip non-matching lines in the answer until a match is found + + if a_index < len(answer_lines) and question_lines[q_index] == answer_lines[a_index]: + if only_comments: + raise ValueError("The user input section contains only comments.") + q_index += 1 + a_index += 1 # Increment both to continue matching + else: + raise ValueError("The question and answer blocks differ after the TODO marker.") + else: + # Match lines one-to-one before the TODO marker + if a_index >= len(answer_lines): + raise ValueError("The answer does not cover all lines in the question before the TODO marker.") + if question_lines[q_index] != answer_lines[a_index]: + raise ValueError("The question and answer blocks differ before the TODO marker.") + q_index += 1 + a_index += 1 + +def _verify_todo_marker_presence(half_completed_code: str): + """Verifies that the text contains the {TODO_MARKER}.""" + if TODO_MARKER not in half_completed_code: + raise ValueError(f"The text does not contain the placeholder {TODO_MARKER}.") \ No newline at end of file diff --git 
a/app/control/post/summariser.py b/app/control/post/summariser.py deleted file mode 100644 index 0ca52ca..0000000 --- a/app/control/post/summariser.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -from typing import Any, Optional - -from app.exceptions.exception import LogicError -from app.prompts.summariser.functions import SummaryFunctions - -log = logging.getLogger(__name__) - - -def post_process( - topic: str, - goal: str, - overview: str, - key_concepts_lst: list[dict[str, str]] - ) -> dict[str, Any]: - """_summary_ - - Args: - topic (str): The topic of the summary - goal (str): The goal of the summary - overview (str): The overview of the summary - key_concepts_lst (list[dict[str, str]]): The list of key concepts of the summary - - Returns: - dict[str, Any]: A dictionary containing the parts of the summary - """ - try: - if not isinstance(topic, str): - raise TypeError(f"Topic is not a string: {topic}") - if not isinstance(goal, str): - raise TypeError(f"Goal is not a string: {goal}") - if not isinstance(overview, str): - raise TypeError(f"Overview is not a string: {overview}") - if not isinstance(key_concepts_lst, list): - raise TypeError(f"Key concepts list is not a list: {key_concepts_lst}") - - _reject_unlikely_topics(topic=topic) - _enforce_code_language_presence(key_concepts_lst=key_concepts_lst) - - return { - SummaryFunctions.TOPIC.value: topic, - SummaryFunctions.GOAL.value: goal, - SummaryFunctions.OVERVIEW.value: overview, - SummaryFunctions.KEY_CONCEPTS.value: key_concepts_lst - } - except (TypeError, ValueError) as e: - log.error(f"Logic error while post-processing summary: {e}") - raise LogicError(message=str(e)) - except Exception as e: - log.error(f"Unexpected error while post-processing summary: {e}") - raise e - - -def _reject_unlikely_topics(topic: str): - """Throws an error if the topic is unlikely to be valid/of good quality. - - The observation is that most valid topics have more than one word. One-word topics generated by LLM tend to be things like "Issue", "Problem", "Solution", etc. that are not what we want. - - Args: - topic (str): the topic-content dictionary to be checked. - """ - - if len(topic.split(" ")) <= 1: - raise ValueError(f"Topic '{topic}' is unlikely to be a valid topic.") - -def _enforce_code_language_presence(key_concepts_lst: list[dict[str, str]]): - """Enforces that the code language is present if the code example is present. - - Args: - key_concepts_lst (list[dict[str, str]]): the list of key concepts to be checked. - """ - for key_concept in key_concepts_lst: - code_example: Optional[dict[str, str]] = key_concept.get(SummaryFunctions.CODE_EXAMPLE.value) - if not code_example: - continue - if code_example.get(SummaryFunctions.CODE.value) and not code_example.get(SummaryFunctions.LANGUAGE.value): - raise ValueError(f"Code example present but code language not specified for key concept: {key_concept}") \ No newline at end of file diff --git a/app/control/pre/summariser.py b/app/control/pre/generator.py similarity index 96% rename from app/control/pre/summariser.py rename to app/control/pre/generator.py index 6cfbebd..6b0c2d8 100644 --- a/app/control/pre/summariser.py +++ b/app/control/pre/generator.py @@ -10,7 +10,7 @@ log = logging.getLogger(__name__) def pre_process( - conversation_dict: dict[str, Any], max_input_tokens: int + conversation: dict[str, Any], max_input_tokens: int ) -> tuple[list[Conversation], int]: """Pre-processes the conversation in preparation for summarisation. 
@@ -23,7 +23,7 @@ def pre_process( """ try: conversation_lst, token_sum = _split_by_token_length( - conversation_dict=conversation_dict, max_input_tokens=max_input_tokens + conversation_dict=conversation, max_input_tokens=max_input_tokens ) return conversation_lst, token_sum except LogicError as e: diff --git a/app/llm/base.py b/app/llm/base.py index 0569558..2c7c9fb 100644 --- a/app/llm/base.py +++ b/app/llm/base.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +from app.models.content import Content from app.prompts.config import PromptMessageConfig @@ -32,7 +33,7 @@ async def send_message( self, system_message: str, user_message: str, - config: PromptMessageConfig + content_lst: list[Content] ) -> str: """Sends a message to the AI and returns the response.""" pass diff --git a/app/llm/open_ai.py b/app/llm/open_ai.py index 7c6c342..479900b 100644 --- a/app/llm/open_ai.py +++ b/app/llm/open_ai.py @@ -6,9 +6,9 @@ from app.exceptions.exception import InferenceFailure from app.llm.base import LLMBaseModel, LLMConfig +from app.models.content import Content from app.prompts.config import PromptMessageConfig -from app.prompts.examiner.functions import PracticeFunctions, get_practice_functions -from app.prompts.summariser.functions import get_summary_functions, SummaryFunctions +from app.prompts.generator.functions import get_notes_functions, NotesFunctions log = logging.getLogger(__name__) @@ -24,73 +24,84 @@ def __init__(self, model_name: str, model_config: LLMConfig): api_key=OPENAI_API_KEY, ) - async def send_message(self, system_message: str, user_message: str, config: PromptMessageConfig) -> Any: + async def send_message(self, system_message: str, user_message: str, content_lst: list[Content]) -> Any: """Sends a message to OpenAI and returns the response.""" log.info(f"Sending messages to OpenAI") - match config: - case PromptMessageConfig.SUMMARY: - response = self._client.chat.completions.create( - model = self._model_name, - messages = [ - {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} - ], - functions=get_summary_functions(), - function_call = {"name": SummaryFunctions.GET_SUMMARY} - ) - try: - json_response: dict[str, str] = json.loads(response.choices[0].message.function_call.arguments) - print("llm_response") - print(json_response) - topic: str = json_response[SummaryFunctions.TOPIC] - goal: str = json_response[SummaryFunctions.GOAL] - overview: str = json_response[SummaryFunctions.OVERVIEW] - key_concepts_lst: list = [] - for key_concept in json_response[SummaryFunctions.KEY_CONCEPTS]: - code_example: Optional[dict[str, str]] = key_concept.get(SummaryFunctions.CODE_EXAMPLE) - if code_example: - key_concepts_lst.append({ - SummaryFunctions.TITLE.value: key_concept[SummaryFunctions.TITLE], - SummaryFunctions.EXPLANATION.value: key_concept[SummaryFunctions.EXPLANATION], - SummaryFunctions.CODE_EXAMPLE.value: { - SummaryFunctions.CODE.value: code_example[SummaryFunctions.CODE], - SummaryFunctions.LANGUAGE.value: code_example[SummaryFunctions.LANGUAGE] - } - }) - else: - key_concepts_lst.append({ - SummaryFunctions.TITLE.value: key_concept[SummaryFunctions.TITLE], - SummaryFunctions.EXPLANATION.value: key_concept[SummaryFunctions.EXPLANATION], - }) - log.info(f"Topic: {topic}, Goal: {goal} Overview: {overview}, Key concepts: {key_concepts_lst}") - return (topic, goal, overview, key_concepts_lst) - except Exception as e: - log.error(f"Error processing or receiving OpenAI response: {str(e)}") - raise InferenceFailure("Error processing OpenAI 
response") - case PromptMessageConfig.PRACTICE: - response = self._client.chat.completions.create( - model = self._model_name, - messages = [ - {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} - ], - functions=get_practice_functions(), - function_call = {"name": PracticeFunctions.GET_PRACTICE} - ) - try: - json_response: dict[str, str] = json.loads(response.choices[0].message.function_call.arguments) - log.info(f"Practice: {json_response}") - language: str = json_response[PracticeFunctions.LANGUAGE] - question: str = json_response[PracticeFunctions.QUESTION] - half_completed_code: str = json_response[PracticeFunctions.HALF_COMPLETED_CODE] - fully_completed_code: str = json_response[PracticeFunctions.FULLY_COMPLETED_CODE] - log.info(f"Language: {language}, Question: {question}, Half-completed-code: {half_completed_code}, Fully-completed-code: {fully_completed_code}") - return (language, question, half_completed_code, fully_completed_code) - except Exception as e: - log.error(f"Error processing or receiving OpenAI response: {str(e)}") - raise InferenceFailure("Error processing OpenAI response") - case _: - raise InferenceFailure("Invalid config type") + try: + response = self._client.chat.completions.create( + model = self._model_name, + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_message} + ], + functions=get_notes_functions( + contains_mcq_practice=bool(Content.MCQ in content_lst), + contains_code_practice=bool(Content.CODE in content_lst) + ), + function_call = {"name": NotesFunctions.GET_NOTES} + ) + try: + json_response: dict[str, str] = json.loads(response.choices[0].message.function_call.arguments) + print("~~~LLM RESPONSE~~~") + print(json_response) + topic: str = json_response[NotesFunctions.TOPIC] + goal: str = json_response[NotesFunctions.GOAL] + overview: str = json_response[NotesFunctions.OVERVIEW] + + key_concepts_lst: list = [] + for key_concept in json_response[NotesFunctions.KEY_CONCEPTS]: + code_example: Optional[dict[str, str]] = key_concept.get(NotesFunctions.KEY_CONCEPT_CODE_EXAMPLE) + if code_example: + key_concepts_lst.append({ + NotesFunctions.KEY_CONCEPT_TITLE.value: key_concept[NotesFunctions.KEY_CONCEPT_TITLE], + NotesFunctions.KEY_CONCEPT_EXPLANATION.value: key_concept[NotesFunctions.KEY_CONCEPT_EXPLANATION], + NotesFunctions.KEY_CONCEPT_CODE_EXAMPLE.value: { + NotesFunctions.KEY_CONCEPT_CODE.value: code_example[NotesFunctions.KEY_CONCEPT_CODE], + NotesFunctions.KEY_CONCEPT_LANGUAGE.value: code_example[NotesFunctions.KEY_CONCEPT_LANGUAGE] + } + }) + else: + key_concepts_lst.append({ + NotesFunctions.KEY_CONCEPT_TITLE.value: key_concept[NotesFunctions.KEY_CONCEPT_TITLE], + NotesFunctions.KEY_CONCEPT_EXPLANATION.value: key_concept[NotesFunctions.KEY_CONCEPT_EXPLANATION], + }) + + tips_lst: list = [] + tips: Optional[list[dict[str, str]]] = json_response.get(NotesFunctions.TIPS) + if tips: + for tip in tips: + tips_lst.append({ + NotesFunctions.TIP_TITLE.value: tip[NotesFunctions.TIP_TITLE], + NotesFunctions.TIP_EXPLANATION.value: tip[NotesFunctions.TIP_EXPLANATION] + }) + + mcq_practice: Optional[dict[str, str]] = json_response.get(NotesFunctions.MCQ_PRACTICE) + if mcq_practice: + mcq_practice = { + NotesFunctions.MCQ_PRACTICE_TITLE.value: mcq_practice[NotesFunctions.MCQ_PRACTICE_TITLE], + NotesFunctions.MCQ_PRACTICE_QUESTION.value: mcq_practice[NotesFunctions.MCQ_PRACTICE_QUESTION], + NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS.value: 
mcq_practice[NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS], + NotesFunctions.MCQ_PRACTICE_CORRECT_OPTION.value: mcq_practice[NotesFunctions.MCQ_PRACTICE_CORRECT_OPTION] + } + + code_practice: Optional[dict[str, str]] = json_response.get(NotesFunctions.CODE_PRACTICE) + if code_practice: + code_practice = { + NotesFunctions.CODE_PRACTICE_TITLE.value: code_practice[NotesFunctions.CODE_PRACTICE_TITLE], + NotesFunctions.CODE_PRACTICE_QUESTION.value: code_practice[NotesFunctions.CODE_PRACTICE_QUESTION], + NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE.value: code_practice[NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE], + NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE.value: code_practice[NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE], + NotesFunctions.CODE_PRACTICE_LANGUAGE.value: code_practice[NotesFunctions.CODE_PRACTICE_LANGUAGE] + } + + log.info(f"Topic: {topic}, Goal: {goal} Overview: {overview}, Key concepts: {key_concepts_lst}, Tips: {tips_lst}, MCQ Practice: {mcq_practice}, Code Practice: {code_practice}") + return (topic, goal, overview, key_concepts_lst, tips_lst, mcq_practice, code_practice) + except Exception as e: + log.error(f"Error processing or receiving OpenAI response: {str(e)}") + raise InferenceFailure("Error processing OpenAI response") + except Exception as e: + log.error(f"Error sending message to OpenAI: {str(e)}") + raise InferenceFailure("Error sending message to OpenAI") diff --git a/app/main.py b/app/main.py index adcde7a..e2b0f7b 100644 --- a/app/main.py +++ b/app/main.py @@ -1,14 +1,12 @@ import logging -from typing import Any, Optional from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from app.exceptions.exception import InferenceFailure, LogicError from app.models.inference import InferenceInput -from app.models.task import Task -from app.scripts.practice import generate_practice -from app.scripts.summary import generate_summary +from app.models.content import Content +from app.scripts.generate import generate log = logging.getLogger(__name__) @@ -26,26 +24,15 @@ async def generate_notes(input: InferenceInput) -> JSONResponse: JSONResponse: The generated notes that will be propagated back to Stomach upon successful inference. """ try: - tasks: list[str] = input.tasks - validated_tasks: list[Task] = Task.validate(tasks) - - summary: Optional[list[dict[str, str]]] = None - practice: Optional[list[dict[str, str]]] = None - for task in validated_tasks: - if task == Task.SUMMARISE: - summary, token_sum = await generate_summary( - conversations=input.conversation - ) - elif task == Task.PRACTICE: - if not summary: - summary, token_sum = await generate_summary( - conversations=input.conversation - ) - practice: dict[str, Any] = await generate_practice(summary=summary) + content: list[str] = input.content + validated_content_lst: list[Content] = Content.validate(content_str_lst=content) + result, token_sum = await generate( + conversation=input.conversation, + content_lst=validated_content_lst + ) return JSONResponse( status_code=200, - # TODO: Wrap everything under result - content={"result": summary, "token_sum": token_sum}, + content={"result": result, "token_sum": token_sum}, ) except LogicError as e: log.error(f"Logic error while trying to generate notes: {str(e)}") @@ -55,12 +42,5 @@ async def generate_notes(input: InferenceInput) -> JSONResponse: log.error(f"Error in generating notes: {str(e)}") # Raise exception only when an unexpected error occurs. If not, try to return good results as much as possible. 
raise HTTPException(status_code=500, detail=str(e)) - - # Returns the parts that have been successfully processed. - if summary or practice: - return JSONResponse( - status_code=200, - # TODO: Wrap everything under result - content={"result": summary, "token_sum": token_sum}, - ) + raise HTTPException(status_code=400, detail="Failed to generate notes completely.") diff --git a/app/models/content.py b/app/models/content.py new file mode 100644 index 0000000..f7490b0 --- /dev/null +++ b/app/models/content.py @@ -0,0 +1,30 @@ +import logging +from enum import StrEnum + +from app.exceptions.exception import LogicError + +log = logging.getLogger(__name__) + +# This class must match Content Enum from fingers and stomach repo +class Content(StrEnum): + MCQ = "mcq" + CODE = "code" + + def validate(content_str_lst: list[str]) -> list["Content"]: + """Validates the content strings and transforms them into Content objects. + + Returns: + list[Content]: The validated content. + """ + validated_content: list[Content] = [] + for content_str in content_str_lst: + try: + validated_content.append(Content(content_str)) + except KeyError as e: + log.error(f"Error validating content because enum string is wrong: {e}") + raise LogicError(f"Invalid content string: {content_str}") from e + except Exception as e: + log.error(f"Unexpected error while validating content: {e}") + raise e + return validated_content + diff --git a/app/models/inference.py b/app/models/inference.py index e0c37aa..7a35bed 100644 --- a/app/models/inference.py +++ b/app/models/inference.py @@ -5,4 +5,4 @@ class InferenceInput(BaseModel): conversation: dict[str, Any] - tasks: list[str] + content: list[str] diff --git a/app/models/task.py b/app/models/task.py deleted file mode 100644 index d070b31..0000000 --- a/app/models/task.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging -from enum import StrEnum - -from app.exceptions.exception import LogicError - -log = logging.getLogger(__name__) - -# This class must match TaskEnum from fingers and stomach repo -class Task(StrEnum): - SUMMARISE = "summarise" - PRACTICE = "practice" - - def validate(task_str_lst: list[str]) -> list["Task"]: - """Validates whether the task strings and transforms them into Task objects. - - Returns: - list[Task]: The validated tasks. 
- """ - validated_tasks: list[Task] = [] - for task_str in task_str_lst: - try: - validated_tasks.append(Task(task_str)) - except KeyError as e: - log.error(f"Error validating task because enum string is wrong: {e}") - raise LogicError(f"Invalid task string: {task_str}") from e - except Exception as e: - log.error(f"Unexpected error while validating tasks: {e}") - raise e - return validated_tasks - diff --git a/app/process/examiner.py b/app/process/examiner.py index 54365bd..a288751 100644 --- a/app/process/examiner.py +++ b/app/process/examiner.py @@ -5,7 +5,7 @@ from app.exceptions.exception import InferenceFailure, LogicError from app.llm.base import LLMBaseModel from app.llm.model import LLM, LLMType -from app.models.task import Task +from app.models.content import Task from app.prompts.config import PromptMessageConfig from app.prompts.examiner.anthropic import ( generate_anthropic_examiner_system_message, diff --git a/app/process/summariser.py b/app/process/generator.py similarity index 76% rename from app/process/summariser.py rename to app/process/generator.py index 74e524e..efdc1d8 100644 --- a/app/process/summariser.py +++ b/app/process/generator.py @@ -2,41 +2,38 @@ from typing import Any from app.config import InferenceConfig -from app.control.post.summariser import post_process -from app.control.pre.summariser import pre_process +from app.control.post.generator import post_process +from app.control.pre.generator import pre_process from app.exceptions.exception import InferenceFailure, LogicError from app.llm.base import LLMBaseModel from app.llm.model import LLM, LLMType from app.models.conversation import Conversation -from app.models.task import Task -from app.prompts.config import PromptMessageConfig -from app.prompts.summariser.anthropic import ( +from app.models.content import Content +from app.prompts.generator.anthropic import ( generate_anthropic_summariser_system_message, generate_anthropic_summariser_user_message) -from app.prompts.summariser.cohere import ( +from app.prompts.generator.cohere import ( generate_cohere_summariser_system_message, generate_cohere_summariser_user_message) -from app.prompts.summariser.google_ai import ( +from app.prompts.generator.google_ai import ( generate_google_ai_summariser_system_message, generate_google_ai_summariser_user_message) -from app.prompts.summariser.llama3 import generate_llama3_summariser_system_message, generate_llama3_summariser_user_message -from app.prompts.summariser.open_ai import ( +from app.prompts.generator.llama3 import generate_llama3_summariser_system_message, generate_llama3_summariser_user_message +from app.prompts.generator.open_ai import ( generate_open_ai_summariser_system_message, generate_open_ai_summariser_user_message) log = logging.getLogger(__name__) -class Summariser: - - TASK = Task.SUMMARISE +class Generator: _llm_type: LLMType _model: LLMBaseModel _max_tokens: int def __init__(self, config: InferenceConfig): - self._llm_type = config.llm_type.get(self.TASK) + self._llm_type = config.llm_type self._model = LLM(model_type=self._llm_type).model self._max_tokens = self._model.model_config.max_tokens @@ -95,43 +92,46 @@ def generate_user_message(self, conversation: Conversation) -> str: ) def pre_process( - self, conversation_dict: dict[str, Any] + self, conversation: dict[str, Any] ) -> tuple[list[Conversation], int]: """Pre-processes the conversation given the summarisation model's max tokens limit. 
The conversation will be split up into multiple conversation chunks if it exceeds the max tokens limit. Args: - conversation_dict (dict[str, Any]): The user's conversation chatlog. + conversation (dict[str, Any]): The user's conversation chatlog. Returns: tuple[list[Conversation], int]: The list of conversation chunks and the total token sum of the conversation. """ conversation_lst, token_sum = pre_process( - conversation_dict=conversation_dict, max_input_tokens=self._max_tokens + conversation=conversation, max_input_tokens=self._max_tokens ) log.info(f"Length of conversation list: {len(conversation_lst)} post split") log.info(f"Token sum of conversation: {token_sum}") return conversation_lst, token_sum - async def summarise(self, conversation: Conversation) -> dict[str, Any]: - """Invokes the LLM to generate a summary of the conversation. + async def generate(self, conversation: Conversation, content_lst: list[Content]) -> dict[str, Any]: + """Invokes the LLM to generate revision notes from the conversation. Args: - conversation (Conversation): The conversation to be summarised. + conversation (Conversation): The conversation to generate revision notes of. + content_lst (list[Content]): The content types that the user wants to generate notes for. Returns: - dict[str, str]: A dictionary containing the topic-content of the summary + dict[str, str]: A dictionary containing the content of the revision notes """ system_message: str = self.generate_system_message() user_message: str = self.generate_user_message(conversation=conversation) try: - topic, goal, overview, key_concepts_lst = await self._model.send_message( - system_message=system_message, user_message=user_message, config=PromptMessageConfig.SUMMARY + topic, goal, overview, key_concepts_lst, tips_lst, mcq_practice, code_practice = await self._model.send_message( + system_message=system_message, + user_message=user_message, + content_lst=content_lst ) processed_summary: dict[str, Any] = post_process( - topic=topic, goal=goal, overview=overview, key_concepts_lst=key_concepts_lst + topic=topic, goal=goal, overview=overview, key_concepts_lst=key_concepts_lst, tips_lst=tips_lst, mcq_practice=mcq_practice, code_practice=code_practice ) log.info(f"Processed Summary: {processed_summary}") return processed_summary diff --git a/app/prompts/examiner/functions.py b/app/prompts/examiner/functions.py deleted file mode 100644 index 207c7c0..0000000 --- a/app/prompts/examiner/functions.py +++ /dev/null @@ -1,40 +0,0 @@ -from enum import StrEnum -from typing import Any - -class PracticeFunctions(StrEnum): - GET_PRACTICE = "get_practice" - LANGUAGE = "language" - QUESTION = "question" - HALF_COMPLETED_CODE = "half_completed_code" - FULLY_COMPLETED_CODE = "fully_completed_code" - -def get_practice_functions() -> list[dict[str, str]]: - practice_functions: list[dict[str, Any]]= [ - { - "name": PracticeFunctions.GET_PRACTICE, - "description": "Generate practice questions based on the summary.", - "parameters": { - "type": "object", - "properties": { - PracticeFunctions.QUESTION: { - "type": "string", - "description": "The coding question that is formulated based on the summary, with enough context and hints for the student to complete the code without ambiguity." - }, - PracticeFunctions.HALF_COMPLETED_CODE: { - "type": "string", - "description": "The half-completed code with the TODO marker in place of the missing code." 
- }, - PracticeFunctions.FULLY_COMPLETED_CODE: { - "type": "string", - "description": "The fully-completed code, with the missing parts annotated by the TODO marker filled." - }, - PracticeFunctions.LANGUAGE: { - "type": "string", - "description": "The programming language used in the practice question." - } - }, - "required": [PracticeFunctions.QUESTION, PracticeFunctions.HALF_COMPLETED_CODE, PracticeFunctions.FULLY_COMPLETED_CODE, PracticeFunctions.LANGUAGE] - } - } - ] - return practice_functions \ No newline at end of file diff --git a/app/prompts/summariser/__init__.py b/app/prompts/generator/__init__.py similarity index 100% rename from app/prompts/summariser/__init__.py rename to app/prompts/generator/__init__.py diff --git a/app/prompts/summariser/anthropic.py b/app/prompts/generator/anthropic.py similarity index 100% rename from app/prompts/summariser/anthropic.py rename to app/prompts/generator/anthropic.py diff --git a/app/prompts/summariser/cohere.py b/app/prompts/generator/cohere.py similarity index 100% rename from app/prompts/summariser/cohere.py rename to app/prompts/generator/cohere.py diff --git a/app/prompts/generator/functions.py b/app/prompts/generator/functions.py new file mode 100644 index 0000000..f80c5e1 --- /dev/null +++ b/app/prompts/generator/functions.py @@ -0,0 +1,200 @@ +from enum import StrEnum +from typing import Any + +class NotesFunctions(StrEnum): + GET_NOTES = "get_notes" + + # Unique element to output + TOPIC = "topic" # Compulsory + GOAL = "goal" # Compulsory + OVERVIEW = "overview" # Compulsory + + # List element to output + KEY_CONCEPTS = "key_concepts" # Compulsory + # List of tuples containing these 3 elements + KEY_CONCEPT_TITLE = "key_concept_title" # Compulsory + KEY_CONCEPT_EXPLANATION = "key_concept_explanation" # Compulsory + KEY_CONCEPT_CODE_EXAMPLE = "key_concept_code_example" # Optional + # KEY_CONCEPT_CODE_EXAMPLE contains these 2 elements + KEY_CONCEPT_CODE = "key_concept_code" # Compulsory + KEY_CONCEPT_LANGUAGE = "key_concept_language" # Compulsory + + # List element to output + TIPS = "tips" # Optional + # List of tuples containing these 2 elements + TIP_TITLE = "tip_title" # Compulsory + TIP_EXPLANATION = "tip_explanation" # Compulsory + + # Unique element to output + MCQ_PRACTICE = "mcq_practice" # Optional + # MCQ_PRACTICE contains these 4 elements + MCQ_PRACTICE_TITLE = "mcq_practice_title" # Compulsory + MCQ_PRACTICE_QUESTION = "mcq_practice_question" # Compulsory + MCQ_PRACTICE_WRONG_OPTIONS = "mcq_practice_wrong_options" # Compulsory + MCQ_PRACTICE_CORRECT_OPTION = "mcq_practice_correct_option" # Compulsory + + # Unique element to output + CODE_PRACTICE = "code_practice" # Optional + # CODE_PRACTICE contains these 3 elements + CODE_PRACTICE_TITLE = "code_practice_title" # Compulsory + CODE_PRACTICE_QUESTION = "code_practice_question" # Compulsory + CODE_PRACTICE_HALF_COMPLETED_CODE = "code_practice_half_completed_code" # Compulsory + CODE_PRACTICE_FULLY_COMPLETED_CODE = "code_practice_fully_completed_code" # Compulsory + CODE_PRACTICE_LANGUAGE = "code_practice_language" # Compulsory + +def get_notes_functions( + contains_mcq_practice: bool, + contains_code_practice: bool +) -> list[dict[str, Any]]: + """Returns the function-calling function that will be passed into the LLM + + Args: + contains_mcq_practice (bool): True if the users want to include multiple-choice questions in the notes. + contains_code_practice (bool): True if the users want to include coding questions in the notes. 
+ + Returns: + list[dict[str, Any]]: The function-calling function that will be passed into the LLM + """ + properties = { + NotesFunctions.TOPIC: { + "type": "string", + "description": "The topic which the revision notes cover in fewer than 7 words." + }, + NotesFunctions.GOAL: { + "type": "string", + "description": "The goal of the revision notes in one sentence. Students should achieve this goal after reading the notes." + }, + NotesFunctions.OVERVIEW: { + "type": "string", + "description": "A high-level summary of the key ideas present in the revision notes in one sentence." + }, + NotesFunctions.KEY_CONCEPTS: { + "type": "array", + "description": "A list of key concepts that students should learn.", + "items": { + "type": "object", + "properties": { + NotesFunctions.KEY_CONCEPT_TITLE: { + "type": "string", + "description": "The title of the key concept." + }, + NotesFunctions.KEY_CONCEPT_EXPLANATION: { + "type": "string", + "description": "State the key concept in one or two sentences." + }, + NotesFunctions.KEY_CONCEPT_CODE_EXAMPLE: { + "type": "object", + "properties": { + NotesFunctions.KEY_CONCEPT_CODE: { + "type": "string", + "description": "The code example illustrating the key concept." + }, + NotesFunctions.KEY_CONCEPT_LANGUAGE: { + "type": "string", + "description": "The programming language of the code example." + } + }, + "required": [NotesFunctions.KEY_CONCEPT_CODE, NotesFunctions.KEY_CONCEPT_LANGUAGE], + } + }, + "required": [NotesFunctions.KEY_CONCEPT_TITLE, NotesFunctions.KEY_CONCEPT_EXPLANATION] + } + }, + NotesFunctions.TIPS: { + "type": "array", + "description": "A list of tips that will help students to apply the key concepts better in the future. Return None if there are no good tips.", + "items": { + "type": "object", + "properties": { + NotesFunctions.TIP_TITLE: { + "type": "string", + "description": "The title of the tip." + }, + NotesFunctions.TIP_EXPLANATION: { + "type": "string", + "description": "State the tip in one or two sentences." + } + }, + "required": [NotesFunctions.TIP_TITLE, NotesFunctions.TIP_EXPLANATION] + } + } + } + + if contains_mcq_practice: + properties[NotesFunctions.MCQ_PRACTICE] = { + "type": "object", + "description": "A multiple-choice question to test students' understanding of the key concepts. Return None if there are no suitable multiple-choice questions.", + "properties": { + NotesFunctions.MCQ_PRACTICE_TITLE: { + "type": "string", + "description": "A short descriptive title for the multiple-choice question." + }, + NotesFunctions.MCQ_PRACTICE_QUESTION: { + "type": "string", + "description": "The multiple-choice question that students have to answer." + }, + NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS: { + "type": "array", + "description": "A list of wrong options for the multiple-choice question.", + "items": { + "type": "string" + } + }, + NotesFunctions.MCQ_PRACTICE_CORRECT_OPTION: { + "type": "string", + "description": "The correct option for the multiple-choice question.", + }, + }, + "required": [NotesFunctions.MCQ_PRACTICE_TITLE, NotesFunctions.MCQ_PRACTICE_QUESTION, NotesFunctions.MCQ_PRACTICE_WRONG_OPTIONS, NotesFunctions.MCQ_PRACTICE_CORRECT_OPTION] + } + + if contains_code_practice: + properties[NotesFunctions.CODE_PRACTICE] = { + "type": "object", + "description": "A coding question to test students' understanding of the key concepts. 
Return None if there are no suitable coding questions.", + "properties": { + NotesFunctions.CODE_PRACTICE_TITLE: { + "type": "string", + "description": "A short descriptive title for the coding question." + }, + NotesFunctions.CODE_PRACTICE_QUESTION: { + "type": "string", + "description": "The coding question that is formulated based on the key concepts, with enough context and hints for the student to complete the code without ambiguity." + }, + NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE: { + "type": "string", + "description": "The half-completed code with the TODO marker in place of the missing code." + }, + NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE: { + "type": "string", + "description": "The fully-completed code, with the missing parts annotated by the TODO marker filled." + }, + NotesFunctions.CODE_PRACTICE_LANGUAGE: { + "type": "string", + "description": "The programming language used in the practice question." + } + }, + "required": [NotesFunctions.CODE_PRACTICE_TITLE, NotesFunctions.CODE_PRACTICE_QUESTION, NotesFunctions.CODE_PRACTICE_HALF_COMPLETED_CODE, NotesFunctions.CODE_PRACTICE_FULLY_COMPLETED_CODE, NotesFunctions.CODE_PRACTICE_LANGUAGE] + } + + notes_functions: list[dict[str, Any]] = [ + { + "name": NotesFunctions.GET_NOTES, + "description": "Generate revision notes based on the key ideas present in the model's response.", + "parameters": { + "type": "object", + "properties": properties, + "required": [ + NotesFunctions.TOPIC, + NotesFunctions.GOAL, + NotesFunctions.OVERVIEW, + NotesFunctions.KEY_CONCEPTS, + NotesFunctions.TIPS, + NotesFunctions.MCQ_PRACTICE, + NotesFunctions.CODE_PRACTICE + ] + } + } + ] + + return notes_functions \ No newline at end of file diff --git a/app/prompts/summariser/google_ai.py b/app/prompts/generator/google_ai.py similarity index 100% rename from app/prompts/summariser/google_ai.py rename to app/prompts/generator/google_ai.py diff --git a/app/prompts/summariser/llama3.py b/app/prompts/generator/llama3.py similarity index 100% rename from app/prompts/summariser/llama3.py rename to app/prompts/generator/llama3.py diff --git a/app/prompts/generator/open_ai.py b/app/prompts/generator/open_ai.py new file mode 100644 index 0000000..fde4acc --- /dev/null +++ b/app/prompts/generator/open_ai.py @@ -0,0 +1,27 @@ +from app.models.conversation import Conversation +from app.process.types import TODO_MARKER + + +def generate_open_ai_summariser_system_message(): + system_message: str = f""" +You are good at generating revision notes from technical conversations and can transform highly specific conversations into transferable software engineering principles. You will be given a conversation between a user and a large language model. The user has asked the model to help him with certain problems he faced while programming. + +Follow these instructions: +1. State the topic which the revision notes cover +2. State the goal of the revision notes and what users should learn after reading through the notes. +3. Provide an overview of the key ideas present in the revision notes. +4. List 2-4 key concepts present in the conversation. Each key concept should have a title, an explanation in one or two sentences. If useful, provide a short code example with appropriate inline comments that illustrate the corresponding key concept and state the programming language of the code. +5. If useful, provide 1-2 tips that will help students to apply the key concepts better in the future. +6. 
If useful, provide a multiple-choice (MCQ) practice question with 3-4 options that tests the student's understanding of the key concepts. +7. If useful, provide a code practice question that tests the student's understanding of the key concepts. The code practice question should be a half-completed block of code with 1-3 lines intentionally left blank for the student to fill in. Indicate with a comment '{TODO_MARKER}' in place of the lines of code that are intentionally left blank. You should also provide a fully completed version of the code, which is an exact replica of the half-completed block of code, except that the '{TODO_MARKER}' is now replaced with the actual 1-3 lines of expected code. Give enough hints and context within the question such that the student can complete the code without any ambiguity. +""" + + + return system_message + + +def generate_open_ai_summariser_user_message(conversation: Conversation): + user_message: str = conversation.stringify() + user_message += "\nGenerate revision notes from the model's response but avoid referencing the model or the user in your notes. Describe the content as if it is from a textbook:" + + return user_message diff --git a/app/prompts/summariser/functions.py b/app/prompts/summariser/functions.py deleted file mode 100644 index 9704449..0000000 --- a/app/prompts/summariser/functions.py +++ /dev/null @@ -1,77 +0,0 @@ -from enum import StrEnum -from typing import Any - -class SummaryFunctions(StrEnum): - GET_SUMMARY = "get_summary" - - # Unique element to output - TOPIC = "topic" - GOAL = "goal" - OVERVIEW = "overview" - - KEY_CONCEPTS = "key_concepts" - # List of tuples containing these 3 elements - TITLE = "title" # Compulsory - EXPLANATION = "explanation" # Compulsory - CODE_EXAMPLE = "code_example" # Optional - # CODE_EXAMPLE contains these 2 compulsory elements - CODE = "code" - LANGUAGE = "language" - -def get_summary_functions() -> list[dict[str, Any]]: - summary_functions: list[dict[str, Any]] = [ - { - "name": SummaryFunctions.GET_SUMMARY, - "description": "Generate revision notes based on the key ideas present in the model's response.", - "parameters": { - "type": "object", - "properties": { - SummaryFunctions.TOPIC: { - "type": "string", - "description": "The topic which the revision notes cover in fewer than 7 words." - }, - SummaryFunctions.GOAL: { - "type": "string", - "description": "The goal of the revision notes in one sentence. Students should achieve this goal after reading the notes." - }, - SummaryFunctions.OVERVIEW: { - "type": "string", - "description": "A high-level summary of the key ideas present in the revision notes in one sentence." - }, - SummaryFunctions.KEY_CONCEPTS: { - "type": "array", - "items": { - "type": "object", - "properties": { - SummaryFunctions.TITLE: { - "type": "string", - "description": "The title of the key concept." - }, - SummaryFunctions.EXPLANATION: { - "type": "string", - "description": "State the key concept in one or two sentences." - }, - SummaryFunctions.CODE_EXAMPLE: { - "type": "object", - "properties": { - SummaryFunctions.CODE: { - "type": "string", - "description": "The code example illustrating the key concept." - }, - SummaryFunctions.LANGUAGE: { - "type": "string", - "description": "The programming language of the code example."
- } - }, - "required": [SummaryFunctions.CODE, SummaryFunctions.LANGUAGE], - } - }, - "required": [SummaryFunctions.TITLE, SummaryFunctions.EXPLANATION] - } - } - }, - "required": [SummaryFunctions.TOPIC, SummaryFunctions.GOAL, SummaryFunctions.OVERVIEW, SummaryFunctions.KEY_CONCEPTS] - } - } - ] - return summary_functions diff --git a/app/prompts/summariser/open_ai.py b/app/prompts/summariser/open_ai.py deleted file mode 100644 index c0b0af8..0000000 --- a/app/prompts/summariser/open_ai.py +++ /dev/null @@ -1,21 +0,0 @@ -from app.models.conversation import Conversation - - -def generate_open_ai_summariser_system_message(): - system_message: str = """ -You are good at generating revision notes from technical conversations and can transform highly specific conversations into transferable software engineering principles. You will be given a conversation between a user and a large language model. The user has asked the model to help him with certain problems he faced while programming. - -Follow these instructions: -1. State the topic which the revision notes cover -2. State the goal of the revision notes and what users should learn after reading through the notes. -3. Provide an overview of the key ideas present in the revision notes. -4. List 2-4 key concepts present in the conversation. Each key concept should have a title, an explanation in one or two sentences. If useful, provide a short code example with appropriate inline comments that illustrate the corresponding key concept and state the programming language of the code. -""" - return system_message - - -def generate_open_ai_summariser_user_message(conversation: Conversation): - user_message: str = conversation.stringify() - user_message += "\nSummarise the key ideas of the model's response but avoid referencing the model or the user in your summary. Describe the content as if it is from a textbook:" - - return user_message diff --git a/app/scripts/summary.py b/app/scripts/generate.py similarity index 70% rename from app/scripts/summary.py rename to app/scripts/generate.py index d718e57..bc749c9 100644 --- a/app/scripts/summary.py +++ b/app/scripts/generate.py @@ -5,23 +5,25 @@ from app.config import InferenceConfig from app.exceptions.exception import InferenceFailure, LogicError from app.models.conversation import Conversation -from app.process.summariser import Summariser +from app.models.content import Content +from app.process.generator import Generator logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) - -async def generate_summary( - conversations: Union[dict[str, Any] | list[Conversation]], +async def generate( + conversation: Union[dict[str, Any] | list[Conversation]], + content_lst: list[Content], attempt: int = 1, max_attempts: int = 1, token_sum: int = 0, ) -> tuple[list[dict[str, Any]], int]: - """Returns the summary in topic-content key-value pairs and the total token sum of the conversation for usage tracking in stomach. + """Returns the generated notes and the total token sum of the conversation for usage tracking in stomach. Args: conversations (Union[dict[str, Any] | list[Conversation]]): The conversation to be summarised. There are two possible types depending on whether it is the first attempt or not. + content_lst (list[Content]): The content types that the user wants to generate notes for. attempt (int, optional): The current attempt number which will be incremented with every retry. Defaults to 1.
max_attempts (int, optional): The attempt number which will cause the inference pipeline to stop when reached. Defaults to 9. token_sum (int, optional): The total token sum that is used by the user to generate the notes. Defaults to 0. @@ -31,26 +33,26 @@ async def generate_summary( """ config = InferenceConfig() - summariser = Summariser(config=config) + generator = Generator(config=config) conversation_lst: list[Conversation] = None if attempt == 1: try: - conversation_lst, token_sum = summariser.pre_process( - conversation_dict=conversations + conversation_lst, token_sum = generator.pre_process( + conversation=conversation ) except LogicError as e: log.error(f"Logic error while trying to pre-process conversation: {str(e)}") raise e else: - conversation_lst = conversations + conversation_lst = conversation - summary: list[dict[str, Any]] = [] + notes: list[dict[str, Any]] = [] remaining_conversations: list[Conversation] = [] - summary_tasks = [ - summariser.summarise(conversation=conversation) for conversation in conversation_lst + generate_tasks = [ + generator.generate(conversation=conversation, content_lst=content_lst) for conversation in conversation_lst ] - results = await asyncio.gather(*summary_tasks, return_exceptions=True) + results = await asyncio.gather(*generate_tasks, return_exceptions=True) for i, result in enumerate(results): if isinstance(result, Exception): @@ -60,14 +62,14 @@ async def generate_summary( ) remaining_conversations.append(conversation_lst[i]) else: - summary.append(result) + notes.append(result) if remaining_conversations and attempt < max_attempts: log.info( f"Retrying summary generation for remaining {len(remaining_conversations)} conversations..." ) - return await generate_summary( - conversations=remaining_conversations, + return await generate( + conversation=remaining_conversations, attempt=attempt + 1, max_attempts=max_attempts, token_sum=token_sum, @@ -77,4 +79,4 @@ async def generate_summary( f"Failed to post-process remaining {len(remaining_conversations)} conversations after {max_attempts} attempts." 
) - return summary, token_sum + return notes, token_sum diff --git a/app/scripts/practice.py b/app/scripts/practice.py index 1d85ff9..cf025a4 100644 --- a/app/scripts/practice.py +++ b/app/scripts/practice.py @@ -4,7 +4,7 @@ from app.config import InferenceConfig from app.exceptions.exception import InferenceFailure, LogicError from app.process.examiner import Examiner -from app.prompts.summariser.functions import SummaryFunctions +from app.prompts.generator.functions import NotesFunctions log = logging.getLogger(__name__) @@ -60,7 +60,7 @@ async def _generate( examiner = Examiner(config=config) try: language, question, half_completed_code, fully_completed_code = await examiner.examine( - topic=summary_chunk[SummaryFunctions.TOPIC.value], + topic=summary_chunk[NotesFunctions.TOPIC.value], summary_chunk=summary_chunk ) diff --git a/tests/control/post/test_summariser.py b/tests/control/post/test_summariser.py index 08c5771..67d67c7 100644 --- a/tests/control/post/test_summariser.py +++ b/tests/control/post/test_summariser.py @@ -1,6 +1,6 @@ import pytest -from app.control.post.summariser import ( +from app.control.post.generator import ( _reject_unlikely_topics, post_process ) diff --git a/tests/control/pre/test_summariser.py b/tests/control/pre/test_summariser.py index c4d2874..e0c671f 100644 --- a/tests/control/pre/test_summariser.py +++ b/tests/control/pre/test_summariser.py @@ -2,7 +2,7 @@ import pytest -from app.control.pre.summariser import _split_by_token_length, pre_process +from app.control.pre.generator import _split_by_token_length, pre_process from app.exceptions.exception import LogicError from app.models.conversation import Conversation @@ -50,7 +50,7 @@ def test_pre_process( return_value=mock_tokenizer, ): result, token_sum = pre_process( - conversation_dict=valid_conversation_dict, max_input_tokens=max_input_tokens + conversation=valid_conversation_dict, max_input_tokens=max_input_tokens ) assert mock_tokenizer.call_count == MOCK_TOKENIZER_CALL_COUNT assert len(result) == expected_number_of_splits @@ -60,7 +60,7 @@ def test_pre_process( def test_pre_process_with_invalid_input(invalid_conversation_dict): with pytest.raises(LogicError): - pre_process(conversation_dict=invalid_conversation_dict, max_input_tokens=100) + pre_process(conversation=invalid_conversation_dict, max_input_tokens=100) @pytest.mark.parametrize( diff --git a/tests/models/test_inference.py b/tests/models/test_inference.py index e15393f..f7c9fdf 100644 --- a/tests/models/test_inference.py +++ b/tests/models/test_inference.py @@ -2,7 +2,7 @@ from pydantic import ValidationError from app.models.inference import InferenceInput -from app.models.task import Task +from app.models.content import Task INFERENCE_INPUT_VALID_DATA = [ ( diff --git a/tests/models/test_task.py b/tests/models/test_task.py index e10a6c7..01d90ea 100644 --- a/tests/models/test_task.py +++ b/tests/models/test_task.py @@ -1,6 +1,6 @@ import pytest -from app.models.task import Task +from app.models.content import Task def test_enum_values(): diff --git a/tests/test_config.py b/tests/test_config.py index bc910e9..7ea2c81 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,7 @@ from app.config import InferenceConfig from app.llm.model import LLMType -from app.models.task import Task +from app.models.content import Task @pytest.fixture def config():
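Taken together, the new Content model, the renamed generator prompts, and the updated InferenceInput give the following rough request flow. This is a sketch under the module layout introduced in this patch; the conversation payload is a hypothetical stand-in for a real chatlog:

    from app.models.content import Content
    from app.models.inference import InferenceInput
    from app.prompts.generator.functions import get_notes_functions

    payload = InferenceInput(
        conversation={"title": "Reversing lists", "messages": []},  # hypothetical chatlog shape
        content=["mcq", "code"],
    )

    # Raw strings from the request are validated into Content members...
    content_lst = Content.validate(content_str_lst=payload.content)

    # ...which decide whether the mcq_practice and code_practice schemas are
    # added to the function-calling definition sent to the LLM.
    functions = get_notes_functions(
        contains_mcq_practice=Content.MCQ in content_lst,
        contains_code_practice=Content.CODE in content_lst,
    )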