Skip to content

Commit

Permalink
add initial multi tenancy support
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Sep 27, 2024
1 parent 66a4592 commit 4896ad9
Show file tree
Hide file tree
Showing 101 changed files with 2,276 additions and 744 deletions.
79 changes: 48 additions & 31 deletions backend/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import asyncio
from logging.config import fileConfig

from typing import Tuple
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy import pool, text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config

# Interpret the config file for Python logging.
Expand All @@ -21,16 +21,22 @@
):
fileConfig(config.config_file_name)

# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_schema_options() -> str:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(','):
if '=' in pair:
key, value = pair.split('=', 1)
x_args[key] = value
schema_name = x_args.get('schema', 'public')
return schema_name

EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

Expand All @@ -46,48 +52,62 @@ def include_object(
return False
return True

EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.

This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True

Calls to context.execute() here emit the given string to the
script output.

"""
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = build_connection_string()
schema = get_schema_options()


context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)

with context.begin_transaction():
context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
schema = get_schema_options()

connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))

connection.execute(text(f'SET search_path TO "{schema}"'))

context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)

with context.begin_transaction():
context.run_migrations()


async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""

print("Running async migrations")
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
Expand All @@ -98,13 +118,10 @@ async def run_async_migrations() -> None:

await connectable.dispose()


def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""

asyncio.run(run_async_migrations())


if context.is_offline_mode():
run_migrations_offline()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
from alembic_utils import encrypt_string
import json

from danswer.utils.encryption import encrypt_string_to_bytes

# revision identifiers, used by Alembic.
revision = "0a98909f2757"
Expand Down Expand Up @@ -57,7 +57,7 @@ def upgrade() -> None:
# In other words, this upgrade does not apply the encryption. Porting existing sensitive data
# and key rotation currently is not supported and will come out in the future
for row_id, creds, _ in results:
creds_binary = encrypt_string_to_bytes(json.dumps(creds))
creds_binary = encrypt_string(json.dumps(creds))
connection.execute(
creds_table.update()
.where(creds_table.c.id == row_id)
Expand Down Expand Up @@ -86,7 +86,7 @@ def upgrade() -> None:
results = connection.execute(sa.select(llm_table))

for row_id, api_key, _ in results:
llm_key = encrypt_string_to_bytes(api_key)
llm_key = encrypt_string(api_key)
connection.execute(
llm_table.update()
.where(llm_table.c.id == row_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from alembic import op
import sqlalchemy as sa

from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "15326fcec57e"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from alembic_utils import NUM_POSTPROCESSED_RESULTS

# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
Expand Down
27 changes: 0 additions & 27 deletions backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
Create Date: 2024-04-15 01:36:02.952809
"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store

# revision identifiers, used by Alembic.
revision = "703313b75876"
Expand Down Expand Up @@ -53,30 +50,6 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
)

try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)

is_enabled = settings.get("enable_token_budget", False)
token_budget = settings.get("token_budget", -1)
period_hours = settings.get("period_hours", -1)

if is_enabled and token_budget > 0 and period_hours > 0:
op.execute(
f"INSERT INTO token_rate_limit \
(enabled, token_budget, period_hours, scope) VALUES \
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
)

# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")

except Exception:
# Ignore if the dynamic config is not found
pass


def downgrade() -> None:
op.drop_table("token_rate_limit__user_group")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
"""
from alembic import op
import sqlalchemy as sa
from alembic_utils import IndexModelStatus, RecencyBiasSetting, SearchType

from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"
Expand Down
2 changes: 1 addition & 1 deletion backend/alembic/versions/b156fa702355_chat_reworked.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "b156fa702355"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""add tenant id to user model
Revision ID: b25c363470f3
Revises: 1f60f60c3401
Create Date: 2024-08-29 17:03:20.794120
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b25c363470f3"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column("user", sa.Column("tenant_id", sa.Text(), nullable=True))


def downgrade() -> None:
op.drop_column("user", "tenant_id")
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@

def upgrade() -> None:
conn = op.get_bind()

existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('SELECT id, chosen_assistants FROM "user"')
)
op.drop_column(
"user",
'user',
"chosen_assistants",
)
op.add_column(
"user",
'user',
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
Expand All @@ -37,7 +38,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
Expand All @@ -46,20 +47,20 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('SELECT id, chosen_assistants FROM user')
)
op.drop_column(
"user",
'user',
"chosen_assistants",
)
op.add_column(
"user",
'user',
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)
Loading

0 comments on commit 4896ad9

Please sign in to comment.