Commit

Update
jinyang628 committed May 18, 2024
1 parent 9642400 commit 5af4a8b
Showing 3 changed files with 17 additions and 13 deletions.
app/main.py (1 addition & 1 deletion)
@@ -29,7 +29,7 @@ async def generate_notes(input: InferenceInput) -> JSONResponse:
     tasks: list[str] = input.tasks
     validated_tasks: list[Task] = Task.validate(tasks)
 
-    summary: Optional[dict[str, str]] = None
+    summary: Optional[list[dict[str, str]]] = None
     practice: Optional[list[dict[str, str]]] = None
     for task in validated_tasks:
         if task == Task.SUMMARISE:
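For context on the type change above, here is a minimal sketch of the old versus new shape of summary, under the assumption that each chunk carries a topic key and a content key (the literal key names are illustrative; the real ones come from SummaryFunctions):

# Sketch only: assumed key names, not the project's actual SummaryFunctions values.

# Before this commit: one merged dict mapping topic -> summary text
old_summary: dict[str, str] = {
    "recursion": "A function that calls itself until it hits a base case.",
    "binary search": "Repeatedly halve a sorted range to find a target.",
}

# After this commit: a list with one dict per summary chunk, each carrying its own topic
new_summary: list[dict[str, str]] = [
    {"topic": "recursion", "content": "A function that calls itself until it hits a base case."},
    {"topic": "binary search", "content": "Repeatedly halve a sorted range to find a target."},
]

Keeping the topic inside each chunk is what lets the per-chunk code in practice.py below drop the separate topic argument.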
app/scripts/practice.py (10 additions & 6 deletions)
@@ -1,19 +1,20 @@
 import logging
 import asyncio
+from typing import Any
 from app.config import InferenceConfig
 from app.exceptions.exception import InferenceFailure, LogicError
 from app.process.examiner import Examiner
+from app.prompts.summariser.functions import SummaryFunctions
 
 log = logging.getLogger(__name__)
 
 
 async def generate_practice(
-    summary: dict[str, str]
+    summary: list[dict[str, Any]]
 ) -> list[dict[str, str]]:
 
     tasks = [
-        _generate(topic, summary_chunk)
-        for topic, summary_chunk in summary.items()
+        _generate(summary_chunk) for summary_chunk in summary
     ]
 
     results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -28,7 +29,6 @@ async def generate_practice(
             continue
         practice.append(
             {
-                "summary_chunk": summary_chunk,
                 "language": result[0],
                 "question": result[1],
                 "half_completed_code": result[2],
@@ -41,7 +41,9 @@
     return practice
 
 async def _generate(
-    topic: str, summary_chunk: str, attempt=1, max_attempts=9
+    summary_chunk: dict[str, Any],
+    attempt=1,
+    max_attempts=1
 ) -> tuple[str, str, str]:
     """Generates a practice question and answer for a given topic and summary chunk.
@@ -58,7 +60,9 @@
     examiner = Examiner(config=config)
     try:
         language, question, half_completed_code, fully_completed_code = await examiner.examine(
-            topic=topic, summary_chunk=summary_chunk
+            topic=summary_chunk[SummaryFunctions.TOPIC.value],
+
+            summary_chunk=summary_chunk
         )
         return language, question, half_completed_code, fully_completed_code
     except LogicError as e:
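A minimal sketch of the new calling convention in practice.py: each chunk dict carries its own topic, which _generate reads through the SummaryFunctions enum instead of receiving it as a separate parameter. The enum values and the stubbed examiner call below are assumptions for illustration, not the project's actual definitions:

import asyncio
from enum import Enum


class SummaryFunctions(Enum):
    # Assumed values; the real enum lives in app.prompts.summariser.functions
    TOPIC = "topic"
    CONTENT = "content"


async def _generate(summary_chunk: dict[str, str]) -> tuple[str, str, str, str]:
    # The topic is now looked up inside the chunk rather than passed separately
    topic = summary_chunk[SummaryFunctions.TOPIC.value]
    # Stand-in for examiner.examine(...): returns
    # (language, question, half_completed_code, fully_completed_code)
    return "python", f"Write a function about {topic}.", "def solve(): ...", "def solve(): return 42"


async def generate_practice(summary: list[dict[str, str]]) -> list[tuple[str, str, str, str]]:
    # One task per chunk, mirroring the list comprehension in the diff
    tasks = [_generate(chunk) for chunk in summary]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    chunks = [{"topic": "recursion", "content": "A function that calls itself."}]
    print(asyncio.run(generate_practice(chunks)))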
app/scripts/summary.py (6 additions & 6 deletions)
@@ -17,7 +17,7 @@ async def generate_summary(
     attempt: int = 1,
     max_attempts: int = 1,
     token_sum: int = 0,
-) -> tuple[dict[str, Any], int]:
+) -> tuple[list[dict[str, Any]], int]:
     """Returns the summary in topic-content key-value pairs and the total token sum of the conversation for usage tracking in stomach.
 
     Args:
@@ -45,22 +45,22 @@
     else:
         conversation_lst = conversations
 
-    merged_summary: dict[str, Any] = {}
+    summary: list[dict[str, Any]] = []
     remaining_conversations: list[Conversation] = []
     summary_tasks = [
         summariser.summarise(conversation=conversation) for conversation in conversation_lst
     ]
-    summaries = await asyncio.gather(*summary_tasks, return_exceptions=True)
+    summary = await asyncio.gather(*summary_tasks, return_exceptions=True)
 
-    for i, result in enumerate(summaries):
+    for i, result in enumerate(summary):
         if isinstance(result, Exception):
             # TODO: Handle exceptions individually
             log.error(
                 f"Error processing conversation {i+1} (attempt {attempt}/{max_attempts}): {result}"
             )
             remaining_conversations.append(conversation_lst[i])
         else:
-            merged_summary.update(result)
+            summary.append(result)
 
     if remaining_conversations and attempt < max_attempts:
         log.info(
@@ -77,4 +77,4 @@
         f"Failed to post-process remaining {len(remaining_conversations)} conversations after {max_attempts} attempts."
     )
 
-    return merged_summary, token_sum
+    return summary, token_sum
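A minimal sketch of the gather-and-retry pattern generate_summary now follows: collect per-conversation summaries into a list instead of merging them into one dict, and retry only the conversations that failed. The summarise_one helper is a hypothetical stand-in for the project's summariser, and the sketch appends successes to a separate collected list rather than to the gathered results it is iterating over, which keeps the loop stable:

import asyncio
from typing import Any


async def summarise_one(conversation: str) -> dict[str, Any]:
    # Hypothetical stand-in for Summariser.summarise(conversation=...)
    if not conversation:
        raise ValueError("empty conversation")
    return {"topic": conversation[:12], "content": f"summary of {conversation!r}"}


async def generate_summary(
    conversations: list[str], max_attempts: int = 2
) -> list[dict[str, Any]]:
    remaining = conversations
    collected: list[dict[str, Any]] = []

    for _ in range(max_attempts):
        results = await asyncio.gather(
            *(summarise_one(c) for c in remaining), return_exceptions=True
        )
        failed: list[str] = []
        for conversation, result in zip(remaining, results):
            if isinstance(result, Exception):
                failed.append(conversation)  # retry this conversation next attempt
            else:
                collected.append(result)  # successes go into a separate list
        remaining = failed
        if not remaining:
            break

    return collected


if __name__ == "__main__":
    print(asyncio.run(generate_summary(["alpha conversation", ""])))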
