Skip to content

Commit

Permalink
add unique constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Dec 15, 2024
1 parent 1b9db0d commit 0b6f570
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
1 change: 1 addition & 0 deletions backend/alembic/versions/91a0a4d62b14_milestone.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def upgrade() -> None:
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)


Expand Down
43 changes: 22 additions & 21 deletions backend/onyx/db/milestone.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

Expand Down Expand Up @@ -27,28 +28,28 @@ def create_milestone(


def create_milestone_if_not_exists(
user: User | None,
event_type: MilestoneRecordType,
db_session: Session,
user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
"""
Create a milestone if it doesn't already exist.
Returns the milestone and a boolean indicating if it was created.
"""
# Every milestone should only happen once per deployment/tenant
stmt = select(Milestone).where(
Milestone.event_type == event_type,
)
result = db_session.execute(stmt)
milestones = result.scalars().all()

if len(milestones) > 1:
raise ValueError(f"Multiple {event_type} milestones found")

if not milestones:
return create_milestone(user, event_type, db_session), True

return milestones[0], False
# Check if it exists
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one_or_none()

if milestone is not None:
return milestone, False

# If it doesn't exist, try to create it.
try:
milestone = create_milestone(user, event_type, db_session)
return milestone, True
except IntegrityError:
# Another thread or process inserted it in the meantime
db_session.rollback()
# Fetch again to return the existing record
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one() # Now should exist
return milestone, False


def update_user_assistant_milestone(
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,8 @@ class Milestone(Base):

user: Mapped[User | None] = relationship("User")

__table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),)


class TaskQueueState(Base):
# Currently refers to Celery Tasks
Expand Down

0 comments on commit 0b6f570

Please sign in to comment.