Skip to content

Commit

Permalink
Group fk fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Sep 30, 2024
1 parent fc45354 commit d724da9
Showing 1 changed file with 111 additions and 101 deletions.
212 changes: 111 additions & 101 deletions backend/ee/danswer/db/user_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import DocumentSet__UserGroup
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import Persona__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
Expand All @@ -33,6 +34,93 @@
logger = setup_logger()


def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)

user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)


def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)


def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)


def _cleanup_persona__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)


def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)


def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)


def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)


def validate_user_creation_permissions(
db_session: Session,
user: User | None,
Expand Down Expand Up @@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
return db_user_group


def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)

user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)


def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)


def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)


def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
Expand Down Expand Up @@ -476,21 +528,6 @@ def update_user_group(
return db_user_group


def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)


def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
Expand All @@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->

_check_user_group_is_modifiable(db_user_group)

_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)

_cleanup_credential__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
_cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_token_rate_limit__user_group_relationships__no_commit(
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_persona__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
outdated_only=False,
)
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)

Expand All @@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session.commit()


def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)


def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
"""
This assumes that all the fk cleanup has already been done.
"""
db_session.delete(user_group)
db_session.commit()


def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
Expand All @@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
db_session.commit()


def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group.id,
outdated_only=False,
)
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)

# need to flush so that we don't get a foreign key error when deleting the user group row
db_session.flush()

db_session.delete(user_group)
db_session.commit()


def delete_user_group_cc_pair_relationship__no_commit(
cc_pair_id: int, db_session: Session
) -> None:
Expand Down

0 comments on commit d724da9

Please sign in to comment.