run pretty + black
pablonyx committed Sep 27, 2024
1 parent a124fbc commit 7a0ba35
Showing 58 changed files with 305 additions and 244 deletions.
22 changes: 14 additions & 8 deletions backend/alembic/env.py
@@ -26,19 +26,22 @@
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]


def get_schema_options() -> str:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(','):
if '=' in pair:
key, value = pair.split('=', 1)
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
schema_name = x_args.get('schema', 'public')
schema_name = x_args.get("schema", "public")
return schema_name


EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}


def include_object(
object: SchemaItem,
name: str,
@@ -56,10 +59,9 @@ def run_migrations_offline() -> None:
url = build_connection_string()
schema = get_schema_options()


context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
@@ -69,17 +71,18 @@ def run_migrations_offline() -> None:
with context.begin_transaction():
context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
schema = get_schema_options()

connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))
connection.execute(text("COMMIT"))

connection.execute(text(f'SET search_path TO "{schema}"'))

context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
@@ -89,6 +92,7 @@ def do_run_migrations(connection: Connection) -> None:
with context.begin_transaction():
context.run_migrations()


async def run_async_migrations() -> None:
print("Running async migrations")
"""Run migrations in 'online' mode."""
@@ -102,10 +106,12 @@ async def run_async_migrations() -> None:

await connectable.dispose()


def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())


if context.is_offline_mode():
run_migrations_offline()
else:
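For context on the env.py hunk above: get_schema_options() turns alembic's "-x" arguments (for example alembic -x schema=tenant_123 upgrade head) into a target schema name, falling back to "public". Below is a minimal standalone sketch of that same parsing logic, with a hard-coded argument list standing in for context.get_x_argument(); the values are hypothetical and for illustration only.

# Sketch of the -x argument parsing done in get_schema_options().
# In env.py the raw list comes from context.get_x_argument(); here it is hard-coded.
x_args_raw = ["schema=tenant_123,foo=bar"]

x_args = {}
for arg in x_args_raw:
    for pair in arg.split(","):
        if "=" in pair:
            key, value = pair.split("=", 1)
            x_args[key] = value

schema_name = x_args.get("schema", "public")  # default schema when none is passed
print(schema_name)  # -> tenant_123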
@@ -24,11 +24,11 @@ def upgrade() -> None:
sa.text('SELECT id, chosen_assistants FROM "user"')
)
op.drop_column(
'user',
"user",
"chosen_assistants",
)
op.add_column(
'user',
"user",
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
@@ -38,7 +38,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
"UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id"
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
@@ -47,20 +47,20 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('SELECT id, chosen_assistants FROM user')
sa.text("SELECT id, chosen_assistants FROM user")
)
op.drop_column(
'user',
"user",
"chosen_assistants",
)
op.add_column(
'user',
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
"UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id"
),
{"chosen_assistants": chosen_assistants, "id": id},
)
14 changes: 7 additions & 7 deletions backend/alembic/versions/dbaa756c2ccf_embedding_models.py
@@ -16,6 +16,7 @@
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"embedding_model",
@@ -68,24 +69,22 @@ def upgrade() -> None:
column("query_prefix", String),
column("passage_prefix", String),
column("index_name", String),
column("status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)),
column(
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
)

# Insert the old embedding model
op.bulk_insert(
EmbeddingModel,
[
old_embedding_model
],
[old_embedding_model],
)

# If the user has not overridden the embedding model, insert the new default model
if not user_overridden_embedding_model:
op.bulk_insert(
EmbeddingModel,
[
new_embedding_model
],
[new_embedding_model],
)

op.add_column(
@@ -123,6 +122,7 @@ def upgrade() -> None:
postgresql_where=sa.text("status = 'FUTURE'"),
)


def downgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
5 changes: 4 additions & 1 deletion backend/alembic_utils.py
@@ -7,6 +7,7 @@

ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""


def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
@@ -20,6 +21,7 @@ def _get_trimmed_key(key: str) -> bytes:

return encoded_key


def encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
return input_str.encode()
@@ -35,8 +37,10 @@ def encrypt_string(input_str: str) -> bytes:

return iv + encrypted_data


NUM_POSTPROCESSED_RESULTS = 20


class IndexModelStatus(str, Enum):
PAST = "PAST"
PRESENT = "PRESENT"
@@ -56,7 +60,6 @@ class SearchType(str, Enum):
SEMANTIC = "semantic"



class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
4 changes: 2 additions & 2 deletions backend/danswer/auth/users.py
@@ -182,6 +182,7 @@ def send_user_verification_email(
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)


def verify_sso_token(token: str) -> dict:
try:
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
@@ -232,14 +233,13 @@ async def create_user_session(user: User, tenant_id: str) -> str:
"sub": str(user.id),
"email": user.email,
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS),
}

token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
return token



class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
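For context on the users.py hunk above: create_user_session builds a standard HS256 JWT with sub, email, tenant_id, and exp claims. A minimal PyJWT sketch of encoding and then decoding such a token follows; the secret and claim values below are placeholders, not the application's real configuration.

# PyJWT sketch mirroring the payload structure shown in create_user_session.
# Secret and claim values are placeholders for illustration only.
from datetime import datetime, timedelta

import jwt  # PyJWT

SECRET = "example-secret"
payload = {
    "sub": "7c9e6679-7425-40de-944b-e07fc1f90ae7",
    "email": "user@example.com",
    "tenant_id": "tenant_123",
    "exp": datetime.utcnow() + timedelta(seconds=3600),
}

token = jwt.encode(payload, SECRET, algorithm="HS256")
decoded = jwt.decode(token, SECRET, algorithms=["HS256"])  # raises if expired or tampered
print(decoded["tenant_id"])  # -> tenant_123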
3 changes: 1 addition & 2 deletions backend/danswer/background/celery/celery_app.py
@@ -1,4 +1,3 @@

from danswer.background.update import get_all_tenant_ids
import logging
import time
@@ -476,5 +475,5 @@ def schedule_tenant_tasks() -> None:
}
)

schedule_tenant_tasks()

schedule_tenant_tasks()
8 changes: 5 additions & 3 deletions backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -46,14 +46,16 @@ def check_for_prune_task(tenant_id: str | None) -> None:
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
tenant_id=tenant_id
tenant_id=tenant_id,
)
)


@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
def prune_documents_task(
connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
@@ -113,7 +115,7 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str |
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tenant_id=tenant_id
tenant_id=tenant_id,
)
except Exception as e:
task_logger.exception(
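The prune_documents_task docstring above describes the pruning comparison as, in essence, a set difference: document IDs stored locally for the cc-pair that no longer appear in the source's freshly pulled ID listing get deleted. An illustrative sketch of that comparison follows; the variable names and ID values are hypothetical, since the real task fetches these sets through connector and database helpers not shown in this diff.

# Illustrative sketch of the comparison described in prune_documents_task's docstring.
# Both ID sets are hypothetical stand-ins for what the connector and DB helpers return.
source_doc_ids = {"doc-1", "doc-2", "doc-3"}  # IDs currently present at the source
local_doc_ids = {"doc-1", "doc-2", "doc-4"}   # IDs stored locally for this cc-pair

# Anything stored locally but missing from the freshly pulled listing is stale.
stale_doc_ids = local_doc_ids - source_doc_ids
print(stale_doc_ids)  # -> {'doc-4'}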
8 changes: 6 additions & 2 deletions backend/danswer/background/connector_deletion.py
@@ -49,7 +49,7 @@ def delete_connector_credential_pair_batch(
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
tenant_id: str | None
tenant_id: str | None,
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
@@ -134,7 +134,11 @@ def delete_connector_credential_pair_batch(
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int, tenant_id: str | None
self: Task,
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
) -> bool:
task_logger.info(f"document_id={document_id}")

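The (truncated) delete_connector_credential_pair_batch docstring above suggests a reference-counting decision: a document is removed outright only when no other cc-pair still uses it; otherwise only its link to the cc-pair being deleted is dropped. A sketch of that decision follows, under that assumption; all structures are hypothetical stand-ins, not the module's real helpers.

# Sketch of the per-document decision suggested by the truncated docstring above.
# The mapping and cc-pair tuple are hypothetical, for illustration only.
doc_to_cc_pairs = {
    "doc-1": {(1, 10), (2, 20)},  # referenced by two cc-pairs
    "doc-2": {(1, 10)},           # referenced only by the cc-pair being deleted
}
cc_pair_being_deleted = (1, 10)  # (connector_id, credential_id)

for doc_id, cc_pairs in doc_to_cc_pairs.items():
    remaining = cc_pairs - {cc_pair_being_deleted}
    if remaining:
        print(f"{doc_id}: keep document, drop its link to {cc_pair_being_deleted}")
    else:
        print(f"{doc_id}: no other cc-pair uses it, delete it entirely")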
33 changes: 18 additions & 15 deletions backend/danswer/background/indexing/run_indexing.py
@@ -1,4 +1,3 @@

import time
import traceback
from datetime import datetime
@@ -44,7 +43,7 @@ def _get_connector_runner(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None
tenant_id: str | None,
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -81,9 +80,7 @@


def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None
db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -104,7 +101,6 @@ def _run_indexing(
primary_index_name=index_name, secondary_index_name=None
)


embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
@@ -173,7 +169,7 @@
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id
tenant_id=tenant_id,
)

all_connector_doc_ids: set[str] = set()
@@ -201,7 +197,9 @@
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(f"Index Attempt was canceled, status is {index_attempt.status}")
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
)

batch_description = []
for doc in doc_batch:
@@ -388,14 +386,13 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA

return attempt


def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
) -> None:


try:
if is_ee:
global_version.set_ee()
@@ -410,19 +407,25 @@ def run_indexing_entrypoint(
attempt = _prepare_index_attempt(db_session, index_attempt_id)

logger.info(
f"Indexing starting for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

_run_indexing(db_session, attempt, tenant_id)

logger.info(
f"Indexing finished for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}")
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)