Skip to content

Commit

Permalink
Proper tenant reset (#3015)
Browse files Browse the repository at this point in the history
* add proper tenant reset

* clear comment

* minor formatting
  • Loading branch information
pablonyx authored Oct 31, 2024
1 parent add87fa commit 0b08bf4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
58 changes: 35 additions & 23 deletions backend/danswer/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,18 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
"""
engine = get_sqlalchemy_engine()

# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

if tenant_id is None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

Expand All @@ -335,30 +342,35 @@ def get_session_with_tenant(
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")

# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection

# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection

# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
cursor.close()

# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()

finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)


def set_search_path_on_checkout(
Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def bulk_invite_users(
)

tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

normalized_emails = []
try:
for email in emails:
Expand All @@ -206,13 +205,16 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)

except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Danswer organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")

initial_invited_users = get_invited_users()

Expand Down

0 comments on commit 0b08bf4

Please sign in to comment.