diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 529112004b0..00ad159c146 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -18,6 +18,7 @@ from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import DocumentSet__UserGroup from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import Persona__UserGroup from danswer.db.models import TokenRateLimit__UserGroup from danswer.db.models import User from danswer.db.models import User__UserGroup @@ -33,6 +34,93 @@ logger = setup_logger() +def _cleanup_user__user_group_relationships__no_commit( + db_session: Session, + user_group_id: int, + user_ids: list[UUID] | None = None, +) -> None: + """NOTE: does not commit the transaction.""" + where_clause = User__UserGroup.user_group_id == user_group_id + if user_ids: + where_clause &= User__UserGroup.user_id.in_(user_ids) + + user__user_group_relationships = db_session.scalars( + select(User__UserGroup).where(where_clause) + ).all() + for user__user_group_relationship in user__user_group_relationships: + db_session.delete(user__user_group_relationship) + + +def _cleanup_credential__user_group_relationships__no_commit( + db_session: Session, + user_group_id: int, +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Credential__UserGroup).filter( + Credential__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_llm_provider__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(LLMProvider__UserGroup).filter( + LLMProvider__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_persona__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Persona__UserGroup).filter( + Persona__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + token_rate_limit__user_group_relationships = db_session.scalars( + select(TokenRateLimit__UserGroup).where( + TokenRateLimit__UserGroup.user_group_id == user_group_id + ) + ).all() + for ( + token_rate_limit__user_group_relationship + ) in token_rate_limit__user_group_relationships: + db_session.delete(token_rate_limit__user_group_relationship) + + +def _cleanup_user_group__cc_pair_relationships__no_commit( + db_session: Session, user_group_id: int, outdated_only: bool +) -> None: + """NOTE: does not commit the transaction.""" + stmt = select(UserGroup__ConnectorCredentialPair).where( + UserGroup__ConnectorCredentialPair.user_group_id == user_group_id + ) + if outdated_only: + stmt = stmt.where( + UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 + ) + user_group__cc_pair_relationships = db_session.scalars(stmt) + for user_group__cc_pair_relationship in user_group__cc_pair_relationships: + db_session.delete(user_group__cc_pair_relationship) + + +def _cleanup_document_set__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.execute( + delete(DocumentSet__UserGroup).where( + DocumentSet__UserGroup.user_group_id == user_group_id + ) + ) + + def validate_user_creation_permissions( db_session: Session, user: User | None, @@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG return db_user_group -def _cleanup_user__user_group_relationships__no_commit( - db_session: Session, - user_group_id: int, - user_ids: list[UUID] | None = None, -) -> None: - """NOTE: does not commit the transaction.""" - where_clause = User__UserGroup.user_group_id == user_group_id - if user_ids: - where_clause &= User__UserGroup.user_id.in_(user_ids) - - user__user_group_relationships = db_session.scalars( - select(User__UserGroup).where(where_clause) - ).all() - for user__user_group_relationship in user__user_group_relationships: - db_session.delete(user__user_group_relationship) - - -def _cleanup_credential__user_group_relationships__no_commit( - db_session: Session, - user_group_id: int, -) -> None: - """NOTE: does not commit the transaction.""" - db_session.query(Credential__UserGroup).filter( - Credential__UserGroup.user_group_id == user_group_id - ).delete(synchronize_session=False) - - -def _cleanup_llm_provider__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - db_session.query(LLMProvider__UserGroup).filter( - LLMProvider__UserGroup.user_group_id == user_group_id - ).delete(synchronize_session=False) - - def _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session: Session, user_group_id: int ) -> None: @@ -476,21 +528,6 @@ def update_user_group( return db_user_group -def _cleanup_token_rate_limit__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - token_rate_limit__user_group_relationships = db_session.scalars( - select(TokenRateLimit__UserGroup).where( - TokenRateLimit__UserGroup.user_group_id == user_group_id - ) - ).all() - for ( - token_rate_limit__user_group_relationship - ) in token_rate_limit__user_group_relationships: - db_session.delete(token_rate_limit__user_group_relationship) - - def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) @@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> _check_user_group_is_modifiable(db_user_group) + _mark_user_group__cc_pair_relationships_outdated__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_credential__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) - _mark_user_group__cc_pair_relationships_outdated__no_commit( + _cleanup_token_rate_limit__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) - _cleanup_token_rate_limit__user_group_relationships__no_commit( + _cleanup_document_set__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_persona__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_user_group__cc_pair_relationships__no_commit( + db_session=db_session, + user_group_id=user_group_id, + outdated_only=False, + ) + _cleanup_llm_provider__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) @@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> db_session.commit() -def _cleanup_user_group__cc_pair_relationships__no_commit( - db_session: Session, user_group_id: int, outdated_only: bool -) -> None: - """NOTE: does not commit the transaction.""" - stmt = select(UserGroup__ConnectorCredentialPair).where( - UserGroup__ConnectorCredentialPair.user_group_id == user_group_id - ) - if outdated_only: - stmt = stmt.where( - UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 - ) - user_group__cc_pair_relationships = db_session.scalars(stmt) - for user_group__cc_pair_relationship in user_group__cc_pair_relationships: - db_session.delete(user_group__cc_pair_relationship) - - -def _cleanup_document_set__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - db_session.execute( - delete(DocumentSet__UserGroup).where( - DocumentSet__UserGroup.user_group_id == user_group_id - ) - ) +def delete_user_group(db_session: Session, user_group: UserGroup) -> None: + """ + This assumes that all the fk cleanup has already been done. + """ + db_session.delete(user_group) + db_session.commit() def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None: @@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non db_session.commit() -def delete_user_group(db_session: Session, user_group: UserGroup) -> None: - _cleanup_llm_provider__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - _cleanup_user__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - _cleanup_user_group__cc_pair_relationships__no_commit( - db_session=db_session, - user_group_id=user_group.id, - outdated_only=False, - ) - _cleanup_document_set__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - - # need to flush so that we don't get a foreign key error when deleting the user group row - db_session.flush() - - db_session.delete(user_group) - db_session.commit() - - def delete_user_group_cc_pair_relationship__no_commit( cc_pair_id: int, db_session: Session ) -> None: