Skip to content

Commit

Permalink
feat: ET-1594: add support for retrieving all prompts by category
Browse files Browse the repository at this point in the history
  • Loading branch information
gethin-dvla committed Oct 21, 2024
1 parent 4ef83f1 commit d6ee77e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
7 changes: 5 additions & 2 deletions lab_gen/services/conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self, app: FastAPI, examples: dict[str, list[str]], prompts: dict[s
self.app = app
self.examples = examples
self.prompts = prompts
self.all_prompts = {}
for key, value in self.prompts.items():
self.all_prompts[key] = value.input_variables

def create_chain(self, llm: BaseLanguageModel, prompt: ChatPromptTemplate) -> RunnableWithMessageHistory:
"""
Expand Down Expand Up @@ -291,9 +294,9 @@ def end(self, user_id: str, conversation_id: str) -> None:
return history.clear()
raise NoConversationError(conversation_id)

def get_prompts(self) -> dict[str, list[str]]:
def get_prompts(self, categories: str) -> dict[str, list[str]]:
"""Gets the example prompts configured for this service."""
return self.examples
return self.all_prompts if categories else self.examples

def get_prompt(self, prompt_id: str) -> StringPromptTemplate:
"""Gets the prompt template for the given prompt ID."""
Expand Down
File renamed without changes.
File renamed without changes.
5 changes: 3 additions & 2 deletions lab_gen/web/api/prompts/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Query
from fastapi.routing import APIRouter

from lab_gen.services.conversation.conversation import ConversationService
Expand All @@ -13,13 +13,14 @@ async def read_prompts(
*,
api_key: bool = Depends(get_api_key), # noqa: ARG001
conversation: ConversationService = Depends(conversation_provider), # noqa: B008
show: str | None = Query(default=None, description="The category of prompts to show."),
) -> dict[str, list[str]]:
"""Returns available prompts.
Returns:
Mapping of prompt names to prompt texts.
"""
return conversation.get_prompts()
return conversation.get_prompts(show)

@router.get("/prompts/{prompt_id}")
async def read_prompt( # noqa: D417
Expand Down

0 comments on commit d6ee77e

Please sign in to comment.