From 970320bd49201ab5705dd4bd02aa3f50d35dab66 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Sun, 8 Dec 2024 19:39:59 -0800 Subject: [PATCH] Persona / prompt hardening (#3375) * Persona / prompt hardening * fix it --- backend/danswer/db/persona.py | 88 +++++++++++-------- backend/danswer/seeding/load_yamls.py | 19 ++-- backend/ee/danswer/server/seeding.py | 11 ++- .../common_utils/managers/persona.py | 2 +- 4 files changed, 74 insertions(+), 46 deletions(-) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index f8602ed5af7..ee97885b376 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -453,9 +453,9 @@ def upsert_persona( """ if persona_id is not None: - persona = db_session.query(Persona).filter_by(id=persona_id).first() + existing_persona = db_session.query(Persona).filter_by(id=persona_id).first() else: - persona = _get_persona_by_name( + existing_persona = _get_persona_by_name( persona_name=name, user=user, db_session=db_session ) @@ -481,62 +481,78 @@ def upsert_persona( prompts = None if prompt_ids is not None: prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all() - if not prompts and prompt_ids: - raise ValueError("prompts not found") + + if prompts is not None and len(prompts) == 0: + raise ValueError( + f"Invalid Persona config, no valid prompts " + f"specified. Specified IDs were: '{prompt_ids}'" + ) # ensure all specified tools are valid if tools: validate_persona_tools(tools) - if persona: + if existing_persona: # Built-in personas can only be updated through YAML configuration. # This ensures that core system personas are not modified unintentionally. - if persona.builtin_persona and not builtin_persona: + if existing_persona.builtin_persona and not builtin_persona: raise ValueError("Cannot update builtin persona with non-builtin.") # this checks if the user has permission to edit the persona - persona = fetch_persona_by_id( - db_session=db_session, persona_id=persona.id, user=user, get_editable=True + # will raise an Exception if the user does not have permission + existing_persona = fetch_persona_by_id( + db_session=db_session, + persona_id=existing_persona.id, + user=user, + get_editable=True, ) # The following update excludes `default`, `built-in`, and display priority. # Display priority is handled separately in the `display-priority` endpoint. # `default` and `built-in` properties can only be set when creating a persona. - persona.name = name - persona.description = description - persona.num_chunks = num_chunks - persona.chunks_above = chunks_above - persona.chunks_below = chunks_below - persona.llm_relevance_filter = llm_relevance_filter - persona.llm_filter_extraction = llm_filter_extraction - persona.recency_bias = recency_bias - persona.llm_model_provider_override = llm_model_provider_override - persona.llm_model_version_override = llm_model_version_override - persona.starter_messages = starter_messages - persona.deleted = False # Un-delete if previously deleted - persona.is_public = is_public - persona.icon_color = icon_color - persona.icon_shape = icon_shape + existing_persona.name = name + existing_persona.description = description + existing_persona.num_chunks = num_chunks + existing_persona.chunks_above = chunks_above + existing_persona.chunks_below = chunks_below + existing_persona.llm_relevance_filter = llm_relevance_filter + existing_persona.llm_filter_extraction = llm_filter_extraction + existing_persona.recency_bias = recency_bias + existing_persona.llm_model_provider_override = llm_model_provider_override + existing_persona.llm_model_version_override = llm_model_version_override + existing_persona.starter_messages = starter_messages + existing_persona.deleted = False # Un-delete if previously deleted + existing_persona.is_public = is_public + existing_persona.icon_color = icon_color + existing_persona.icon_shape = icon_shape if remove_image or uploaded_image_id: - persona.uploaded_image_id = uploaded_image_id - persona.is_visible = is_visible - persona.search_start_date = search_start_date - persona.category_id = category_id + existing_persona.uploaded_image_id = uploaded_image_id + existing_persona.is_visible = is_visible + existing_persona.search_start_date = search_start_date + existing_persona.category_id = category_id # Do not delete any associations manually added unless # a new updated list is provided if document_sets is not None: - persona.document_sets.clear() - persona.document_sets = document_sets or [] + existing_persona.document_sets.clear() + existing_persona.document_sets = document_sets or [] if prompts is not None: - persona.prompts.clear() - persona.prompts = prompts or [] + existing_persona.prompts.clear() + existing_persona.prompts = prompts if tools is not None: - persona.tools = tools or [] + existing_persona.tools = tools or [] + + persona = existing_persona else: - persona = Persona( + if not prompts: + raise ValueError( + "Invalid Persona config. " + "Must specify at least one prompt for a new persona." + ) + + new_persona = Persona( id=persona_id, user_id=user.id if user else None, is_public=is_public, @@ -549,7 +565,7 @@ def upsert_persona( llm_filter_extraction=llm_filter_extraction, recency_bias=recency_bias, builtin_persona=builtin_persona, - prompts=prompts or [], + prompts=prompts, document_sets=document_sets or [], llm_model_provider_override=llm_model_provider_override, llm_model_version_override=llm_model_version_override, @@ -564,8 +580,8 @@ def upsert_persona( is_default_persona=is_default_persona, category_id=category_id, ) - db_session.add(persona) - + db_session.add(new_persona) + persona = new_persona if commit: db_session.commit() else: diff --git a/backend/danswer/seeding/load_yamls.py b/backend/danswer/seeding/load_yamls.py index c93851dd5c0..7d589dfbee8 100644 --- a/backend/danswer/seeding/load_yamls.py +++ b/backend/danswer/seeding/load_yamls.py @@ -79,6 +79,9 @@ def load_personas_from_yaml( if prompts: prompt_ids = [prompt.id for prompt in prompts if prompt is not None] + if not prompt_ids: + raise ValueError("Invalid Persona config, no prompts exist") + p_id = persona.get("id") tool_ids = [] @@ -123,12 +126,16 @@ def load_personas_from_yaml( tool_ids=tool_ids, builtin_persona=True, is_public=True, - display_priority=existing_persona.display_priority - if existing_persona is not None - else persona.get("display_priority"), - is_visible=existing_persona.is_visible - if existing_persona is not None - else persona.get("is_visible"), + display_priority=( + existing_persona.display_priority + if existing_persona is not None + else persona.get("display_priority") + ), + is_visible=( + existing_persona.is_visible + if existing_persona is not None + else persona.get("is_visible") + ), db_session=db_session, ) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index b4920815a83..f1081fe5f37 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> if personas: logger.notice("Seeding Personas") for persona in personas: + if not persona.prompt_ids: + raise ValueError( + f"Invalid Persona with name {persona.name}; no prompts exist" + ) + upsert_persona( user=None, # Seeding is done as admin name=persona.name, description=persona.description, - num_chunks=persona.num_chunks - if persona.num_chunks is not None - else 0.0, + num_chunks=( + persona.num_chunks if persona.num_chunks is not None else 0.0 + ), llm_relevance_filter=persona.llm_relevance_filter, llm_filter_extraction=persona.llm_filter_extraction, recency_bias=RecencyBiasSetting.AUTO, diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index de2d9db25c1..e5392dfb68b 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -42,7 +42,7 @@ def create( "is_public": is_public, "llm_filter_extraction": llm_filter_extraction, "recency_bias": recency_bias, - "prompt_ids": prompt_ids or [], + "prompt_ids": prompt_ids or [0], "document_set_ids": document_set_ids or [], "tool_ids": tool_ids or [], "llm_model_provider_override": llm_model_provider_override,