diff --git a/backend/alembic/versions/ac5eaac849f9_add_last_pruned_to_connector_table.py b/backend/alembic/versions/ac5eaac849f9_add_last_pruned_to_connector_table.py new file mode 100644 index 00000000000..b2c33e1688d --- /dev/null +++ b/backend/alembic/versions/ac5eaac849f9_add_last_pruned_to_connector_table.py @@ -0,0 +1,27 @@ +"""add last_pruned to the connector_credential_pair table + +Revision ID: ac5eaac849f9 +Revises: 52a219fb5233 +Create Date: 2024-09-10 15:04:26.437118 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "ac5eaac849f9" +down_revision = "46b7a812670f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # last pruned represents the last time the connector was pruned + op.add_column( + "connector_credential_pair", + sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("connector_credential_pair", "last_pruned") diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index 0440f275c36..01da90642b5 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -19,6 +19,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import celery_is_worker_primary @@ -108,6 +109,14 @@ def celery_task_postrun( r.srem(rcd.taskset_key, task_id) return + if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX): + r = redis_pool.get_client() + cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id) + if cc_pair_id is not None: + rcp = RedisConnectorPruning(cc_pair_id) + r.srem(rcp.taskset_key, task_id) + return + @beat_init.connect def on_beat_init(sender: Any, **kwargs: Any) -> None: @@ -240,6 +249,18 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): r.delete(key) + for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): + r.delete(key) + @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: @@ -334,7 +355,11 @@ def on_setup_logging( class HubPeriodicTask(bootsteps.StartStopStep): """Regularly reacquires the primary worker lock outside of the task queue. - Use the task_logger in this class to avoid double logging.""" + Use the task_logger in this class to avoid double logging. + + This cannot be done inside a regular beat task because it must run on schedule and + a queue of existing work would starve the task from running. 
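As a rough illustration of the idea in this docstring (not the actual bootstep), a timer-driven callback can keep extending a redis-py lock's TTL for as long as the worker still owns it. The lock name, TTL, and interval below are made up for the sketch.

import threading

import redis

r = redis.Redis()
primary_lock = r.lock("da_lock:primary_worker", timeout=120)  # name and TTL are illustrative

def reacquire_periodically(interval_secs: float = 30.0) -> None:
    # extend the TTL only while we still own the lock; otherwise let it lapse
    if primary_lock.owned():
        primary_lock.reacquire()
        threading.Timer(
            interval_secs, reacquire_periodically, args=(interval_secs,)
        ).start()

if primary_lock.acquire(blocking=False):
    reacquire_periodically()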
+ """ # it's unclear to me whether using the hub's timer or the bootstep timer is better requires = {"celery.worker.components:Hub"} @@ -368,8 +393,6 @@ def run_periodic_task(self, worker: Any) -> None: lock: redis.lock.Lock = worker.primary_worker_lock - task_logger.info("Reacquiring primary worker lock.") - if lock.owned(): task_logger.debug("Reacquiring primary worker lock.") lock.reacquire() @@ -411,6 +434,7 @@ def stop(self, worker: Any) -> None: "danswer.background.celery.tasks.connector_deletion", "danswer.background.celery.tasks.periodic", "danswer.background.celery.tasks.pruning", + "danswer.background.celery.tasks.shared", "danswer.background.celery.tasks.vespa", ] ) @@ -431,7 +455,7 @@ def stop(self, worker: Any) -> None: "task": "check_for_connector_deletion_task", # don't need to check too often, since we kick off a deletion initially # during the API call that actually marks the CC pair for deletion - "schedule": timedelta(minutes=1), + "schedule": timedelta(seconds=60), "options": {"priority": DanswerCeleryPriority.HIGH}, }, } @@ -439,8 +463,8 @@ def stop(self, worker: Any) -> None: celery_app.conf.beat_schedule.update( { "check-for-prune": { - "task": "check_for_prune_task", - "schedule": timedelta(seconds=5), + "task": "check_for_prune_task_2", + "schedule": timedelta(seconds=60), "options": {"priority": DanswerCeleryPriority.HIGH}, }, } diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index 1d837bd51e0..d2ba49cebd5 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -343,6 +343,110 @@ def generate_tasks( return len(async_results) +class RedisConnectorPruning(RedisObjectHelper): + """Celery will kick off a long running generator task to crawl the connector and + find any missing docs, which will each then get a new cleanup task. The progress of + those tasks will then be monitored to completion. 
+ + Example rough happy path order: + Check connectorpruning_fence_1 + Send generator task with id connectorpruning+generator_1_{uuid} + + generator runs connector with callbacks that increment connectorpruning_generator_progress_1 + generator creates many subtasks with id connectorpruning+sub_1_{uuid} + in taskset connectorpruning_taskset_1 + on completion, generator sets connectorpruning_generator_complete_1 + + celery postrun removes subtasks from taskset + monitor beat task cleans up when taskset reaches 0 items + """ + + PREFIX = "connectorpruning" + FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process + GENERATOR_TASK_PREFIX = PREFIX + "+generator" + + TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's + SUBTASK_PREFIX = PREFIX + "+sub" + + GENERATOR_PROGRESS_PREFIX = ( + PREFIX + "_generator_progress" + ) # a signal that contains generator progress + GENERATOR_COMPLETE_PREFIX = ( + PREFIX + "_generator_complete" + ) # a signal that the generator has finished + + def __init__(self, id: int) -> None: + super().__init__(id) + self.documents_to_prune: set[str] = set() + + @property + def generator_task_id_prefix(self) -> str: + return f"{self.GENERATOR_TASK_PREFIX}_{self._id}" + + @property + def generator_progress_key(self) -> str: + # example: connectorpruning_generator_progress_1 + return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}" + + @property + def generator_complete_key(self) -> str: + # example: connectorpruning_generator_complete_1 + return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}" + + @property + def subtask_id_prefix(self) -> str: + return f"{self.SUBTASK_PREFIX}_{self._id}" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock | None, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + cc_pair = get_connector_credential_pair_from_id(self._id, db_session) + if not cc_pair: + return None + + for doc_id in self.documents_to_prune: + current_time = time.monotonic() + if lock and current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}" + + # add to the tracking taskset in redis BEFORE creating the celery task. + # note that for the moment we are using a single taskset key, not differentiated by cc_pair id + redis_client.sadd(self.taskset_key, custom_task_id) + + # Priority on sync's triggered by new indexing should be medium + result = celery_app.send_task( + "document_by_cc_pair_cleanup_task", + kwargs=dict( + document_id=doc_id, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ), + queue=DanswerCeleryQueues.CONNECTOR_DELETION, + task_id=custom_task_id, + priority=DanswerCeleryPriority.MEDIUM, + ) + + async_results.append(result) + + return len(async_results) + + def celery_get_queue_length(queue: str, r: Redis) -> int: """This is a redis specific way to get the length of a celery queue. 
It is priority aware and knows how to count across the multiple redis lists diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index 9ee282e1af3..638fa9a66d2 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime from datetime import timezone from typing import Any @@ -5,8 +6,7 @@ from sqlalchemy.orm import Session from danswer.background.celery.celery_redis import RedisConnectorDeletion -from danswer.background.task_utils import name_cc_prune_task -from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING +from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, @@ -17,14 +17,9 @@ from danswer.connectors.interfaces import PollConnector from danswer.connectors.models import Document from danswer.db.connector_credential_pair import get_connector_credential_pair -from danswer.db.engine import get_db_current_time +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.enums import TaskStatus -from danswer.db.models import Connector -from danswer.db.models import Credential from danswer.db.models import TaskQueueState -from danswer.db.tasks import check_task_is_live_and_not_timed_out -from danswer.db.tasks import get_latest_task -from danswer.db.tasks import get_latest_task_by_type from danswer.redis.redis_pool import RedisPool from danswer.server.documents.models import DeletionAttemptSnapshot from danswer.utils.logger import setup_logger @@ -33,6 +28,24 @@ redis_pool = RedisPool() +# TODO: make this a member of RedisConnectorPruning +def cc_pair_is_pruning(cc_pair_id: int, db_session: Session) -> bool: + # + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, db_session=db_session + ) + if not cc_pair: + raise ValueError(f"cc_pair_id {cc_pair_id} does not exist.") + + rcp = RedisConnectorPruning(cc_pair.id) + + r = redis_pool.get_client() + if r.exists(rcp.fence_key): + return True + + return False + + def _get_deletion_status( connector_id: int, credential_id: int, db_session: Session ) -> TaskQueueState | None: @@ -70,72 +83,19 @@ def get_deletion_attempt_snapshot( ) -def skip_cc_pair_pruning_by_task( - pruning_task: TaskQueueState | None, db_session: Session -) -> bool: - """task should be the latest prune task for this cc_pair""" - if not ALLOW_SIMULTANEOUS_PRUNING: - # if only one prune is allowed at any time, then check to see if any prune - # is active - pruning_type_task_name = name_cc_prune_task() - last_pruning_type_task = get_latest_task_by_type( - pruning_type_task_name, db_session - ) - - if last_pruning_type_task and check_task_is_live_and_not_timed_out( - last_pruning_type_task, db_session - ): - return True - - if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session): - # if the last task is live right now, we shouldn't start a new one - return True - - return False - - -def should_prune_cc_pair( - connector: Connector, credential: Credential, db_session: Session -) -> bool: - if not connector.prune_freq: - return False - - pruning_task_name = name_cc_prune_task( - connector_id=connector.id, credential_id=credential.id - ) - last_pruning_task = 
get_latest_task(pruning_task_name, db_session) - - if skip_cc_pair_pruning_by_task(last_pruning_task, db_session): - return False - - current_db_time = get_db_current_time(db_session) - - if not last_pruning_task: - # If the connector has never been pruned, then compare vs when the connector - # was created - time_since_initialization = current_db_time - connector.time_created - if time_since_initialization.total_seconds() >= connector.prune_freq: - return True - return False - - if not last_pruning_task.start_time: - # if the last prune task hasn't started, we shouldn't start a new one - return False - - # if the last prune task has a start time, then compare against it to determine - # if we should start - time_since_last_pruning = current_db_time - last_pruning_task.start_time - return time_since_last_pruning.total_seconds() >= connector.prune_freq - - def document_batch_to_ids(doc_batch: list[Document]) -> set[str]: return {doc.id for doc in doc_batch} -def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]: +def extract_ids_from_runnable_connector( + runnable_connector: BaseConnector, + progress_callback: Callable[[int], None] | None = None, +) -> set[str]: """ If the PruneConnector hasnt been implemented for the given connector, just pull - all docs using the load_from_state and grab out the IDs + all docs using the load_from_state and grab out the IDs. + + Optionally, a callback can be passed to handle the length of each document batch. """ all_connector_doc_ids: set[str] = set() @@ -158,6 +118,8 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 )(document_batch_to_ids) for doc_batch in doc_batch_generator: + if progress_callback: + progress_callback(len(doc_batch)) all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) return all_connector_doc_ids @@ -177,9 +139,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool: def celery_is_worker_primary(worker: Any) -> bool: - """There are multiple approaches that could be taken, but the way we do it is to - check the hostname set for the celery worker, either in celeryconfig.py or on the - command line.""" + """There are multiple approaches that could be taken to determine if a celery worker + is 'primary', as defined by us. 
But the way we do it is to check the hostname set + for the celery worker, which can be done either in celeryconfig.py or on the + command line with '--hostname'.""" hostname = worker.hostname if hostname.startswith("light"): return False diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 655487f7168..c6df6c70ac9 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -1,12 +1,12 @@ import redis from celery import shared_task from celery.exceptions import SoftTimeLimitExceeded -from celery.utils.log import get_task_logger from redis import Redis from sqlalchemy.orm import Session from sqlalchemy.orm.exc import ObjectDeletedError from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_app import task_logger from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -22,9 +22,6 @@ redis_pool = RedisPool() -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) - @shared_task( name="check_for_connector_deletion_task", diff --git a/backend/danswer/background/celery/tasks/periodic/tasks.py b/backend/danswer/background/celery/tasks/periodic/tasks.py index bd3b082aeb8..99b1cab7e77 100644 --- a/backend/danswer/background/celery/tasks/periodic/tasks.py +++ b/backend/danswer/background/celery/tasks/periodic/tasks.py @@ -7,18 +7,15 @@ from celery import shared_task from celery.contrib.abortable import AbortableTask # type: ignore from celery.exceptions import TaskRevokedError -from celery.utils.log import get_task_logger from sqlalchemy import inspect from sqlalchemy import text from sqlalchemy.orm import Session +from danswer.background.celery.celery_app import task_logger from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import PostgresAdvisoryLocks from danswer.db.engine import get_sqlalchemy_engine # type: ignore -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) - @shared_task( name="kombu_message_cleanup_task", diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 2f840e430ae..aff5ff15044 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -1,61 +1,167 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from uuid import uuid4 + +import redis from celery import shared_task -from celery.utils.log import get_task_logger +from celery.exceptions import SoftTimeLimitExceeded +from redis import Redis from sqlalchemy.orm import Session from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector -from danswer.background.celery.celery_utils import should_prune_cc_pair -from danswer.background.connector_deletion import delete_connector_credential_pair_batch -from danswer.background.task_utils import build_celery_task_wrapper -from 
danswer.background.task_utils import name_cc_prune_task +from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerRedisLocks from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.document import get_documents_for_connector_credential_pair from danswer.db.engine import get_sqlalchemy_engine -from danswer.document_index.document_index_utils import get_both_index_names -from danswer.document_index.factory import get_default_document_index - +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.models import ConnectorCredentialPair +from danswer.redis.redis_pool import RedisPool -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) +redis_pool = RedisPool() @shared_task( - name="check_for_prune_task", + name="check_for_prune_task_2", soft_time_limit=JOB_TIMEOUT, ) -def check_for_prune_task() -> None: - """Runs periodically to check if any prune tasks should be run and adds them - to the queue""" - - with Session(get_sqlalchemy_engine()) as db_session: - all_cc_pairs = get_connector_credential_pairs(db_session) - - for cc_pair in all_cc_pairs: - if should_prune_cc_pair( - connector=cc_pair.connector, - credential=cc_pair.credential, - db_session=db_session, - ): - task_logger.info(f"Pruning the {cc_pair.connector.name} connector") - - prune_documents_task.apply_async( - kwargs=dict( - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - ) +def check_for_prune_task_2() -> None: + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + cc_pairs = get_connector_credential_pairs(db_session) + for cc_pair in cc_pairs: + tasks_created = ccpair_pruning_generator_task_creation_helper( + cc_pair, db_session, r, lock_beat ) - - -@build_celery_task_wrapper(name_cc_prune_task) -@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT) -def prune_documents_task(connector_id: int, credential_id: int) -> None: + if not tasks_created: + continue + + task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception") + finally: + if lock_beat.owned(): + lock_beat.release() + + +def ccpair_pruning_generator_task_creation_helper( + cc_pair: ConnectorCredentialPair, + db_session: Session, + r: Redis, + lock_beat: redis.lock.Lock, +) -> int | None: + """Returns an int if pruning is triggered. + The int represents the number of prune tasks generated (in this case, only one + because the task is a long running generator task.) + Returns None if no pruning is triggered (due to not being needed or + other reasons such as simultaneous pruning restrictions. 
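Condensed sketch of the non-blocking beat-lock pattern used by check_for_prune_task_2 above, with an illustrative lock timeout; overlapping beat runs simply return instead of queueing behind one another.

import redis

r = redis.Redis()
lock_beat = r.lock("da_lock:check_prune_beat", timeout=120)  # timeout is illustrative

try:
    # if another beat run still holds the lock, skip this cycle entirely
    if lock_beat.acquire(blocking=False):
        pass  # scan cc_pairs and create pruning generator tasks here
finally:
    if lock_beat.owned():
        lock_beat.release()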
+ + Checks for scheduling related conditions, then delegates the rest of the checks to + try_creating_prune_generator_task. + """ + + lock_beat.reacquire() + + # skip pruning if no prune frequency is set + # pruning can still be forced via the API which will run a pruning task directly + if not cc_pair.connector.prune_freq: + return None + + # skip pruning if the next scheduled prune time hasn't been reached yet + last_pruned = cc_pair.last_pruned + if not last_pruned: + # if never pruned, use the connector time created as the last_pruned time + last_pruned = cc_pair.connector.time_created + + next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq) + if datetime.now(timezone.utc) < next_prune: + return None + + return try_creating_prune_generator_task(cc_pair, db_session, r) + + +def try_creating_prune_generator_task( + cc_pair: ConnectorCredentialPair, + db_session: Session, + r: Redis, +) -> int | None: + """Checks for any conditions that should block the pruning generator task from being + created, then creates the task. + + Does not check for scheduling related conditions as this function + is used to trigger prunes immediately. + """ + + if not ALLOW_SIMULTANEOUS_PRUNING: + for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): + return None + + rcp = RedisConnectorPruning(cc_pair.id) + + # skip pruning if already pruning + if r.exists(rcp.fence_key): + return None + + # skip pruning if the cc_pair is deleting + db_session.refresh(cc_pair) + if cc_pair.status == ConnectorCredentialPairStatus.DELETING: + return None + + # add a long running generator task to the queue + r.delete(rcp.generator_complete_key) + r.delete(rcp.taskset_key) + + custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}" + + celery_app.send_task( + "connector_pruning_generator_task", + kwargs=dict( + connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id + ), + queue=DanswerCeleryQueues.CONNECTOR_PRUNING, + task_id=custom_task_id, + priority=DanswerCeleryPriority.LOW, + ) + + # set this only after all tasks have been added + r.set(rcp.fence_key, 1) + return 1 + + +@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT) +def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None: """connector pruning task. 
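The schedule gate above boils down to a small pure function; this is a hedged restatement, assuming prune_freq is a number of seconds and that last_pruned falls back to the connector's creation time when the pair has never been pruned.

from datetime import datetime, timedelta, timezone

def is_prune_due(
    last_pruned: datetime | None,
    connector_time_created: datetime,
    prune_freq: int | None,
) -> bool:
    # no frequency configured means scheduled pruning is disabled
    if not prune_freq:
        return False
    reference = last_pruned or connector_time_created
    next_prune = reference + timedelta(seconds=prune_freq)
    return datetime.now(timezone.utc) >= next_prune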
For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list""" + + r = redis_pool.get_client() + with Session(get_sqlalchemy_engine()) as db_session: try: cc_pair = get_connector_credential_pair( @@ -70,6 +176,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) return + rcp = RedisConnectorPruning(cc_pair.id) + + # Define the callback function + def redis_increment_callback(amount: int) -> None: + r.incrby(rcp.generator_progress_key, amount) + runnable_connector = instantiate_connector( db_session, cc_pair.connector.source, @@ -78,10 +190,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: cc_pair.credential, ) + # a list of docs in the source all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( - runnable_connector + runnable_connector, redis_increment_callback ) + # a list of docs in our local index all_indexed_document_ids = { doc.id for doc in get_documents_for_connector_credential_pair( @@ -91,30 +205,37 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) } + # generate list of docs to remove (no longer in the source) doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids) - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + task_logger.info( + f"Pruning set collected: " + f"cc_pair_id={cc_pair.id} " + f"docs_to_remove={len(doc_ids_to_remove)} " + f"doc_source={cc_pair.connector.source}" ) - if len(doc_ids_to_remove) == 0: - task_logger.info( - f"No docs to prune from {cc_pair.connector.source} connector" - ) - return + rcp.documents_to_prune = set(doc_ids_to_remove) task_logger.info( - f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector" + f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}" ) - delete_connector_credential_pair_batch( - document_ids=doc_ids_to_remove, - connector_id=connector_id, - credential_id=credential_id, - document_index=document_index, + tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None) + if tasks_generated is None: + return None + + task_logger.info( + f"RedisConnectorPruning.generate_tasks finished. " + f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" ) + + r.set(rcp.generator_complete_key, tasks_generated) except Exception as e: task_logger.exception( f"Failed to run pruning for connector id {connector_id}." 
) + + r.delete(rcp.generator_progress_key) + r.delete(rcp.taskset_key) + r.delete(rcp.fence_key) raise e diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py new file mode 100644 index 00000000000..6f86c8959d3 --- /dev/null +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -0,0 +1,104 @@ +from celery import shared_task +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from sqlalchemy.orm import Session + +from danswer.access.access import get_access_for_document +from danswer.background.celery.celery_app import task_logger +from danswer.db.document import delete_document_by_connector_credential_pair__no_commit +from danswer.db.document import delete_documents_complete__no_commit +from danswer.db.document import get_document +from danswer.db.document import get_document_connector_count +from danswer.db.document import mark_document_as_synced +from danswer.db.document_set import fetch_document_sets_for_document +from danswer.db.engine import get_sqlalchemy_engine +from danswer.document_index.document_index_utils import get_both_index_names +from danswer.document_index.factory import get_default_document_index +from danswer.document_index.interfaces import UpdateRequest +from danswer.server.documents.models import ConnectorCredentialPairIdentifier + + +@shared_task( + name="document_by_cc_pair_cleanup_task", + bind=True, + soft_time_limit=45, + time_limit=60, + max_retries=3, +) +def document_by_cc_pair_cleanup_task( + self: Task, document_id: str, connector_id: int, credential_id: int +) -> bool: + task_logger.info(f"document_id={document_id}") + + try: + with Session(get_sqlalchemy_engine()) as db_session: + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + count = get_document_connector_count(db_session, document_id) + if count == 1: + # count == 1 means this is the only remaining cc_pair reference to the doc + # delete it from vespa and the db + document_index.delete(doc_ids=[document_id]) + delete_documents_complete__no_commit( + db_session=db_session, + document_ids=[document_id], + ) + elif count > 1: + # count > 1 means the document still has cc_pair references + doc = get_document(document_id, db_session) + if not doc: + return False + + # the below functions do not include cc_pairs being deleted. + # i.e. they will correctly omit access for the current cc_pair + doc_access = get_access_for_document( + document_id=document_id, db_session=db_session + ) + + doc_sets = fetch_document_sets_for_document(document_id, db_session) + update_doc_sets: set[str] = set(doc_sets) + + update_request = UpdateRequest( + document_ids=[document_id], + document_sets=update_doc_sets, + access=doc_access, + boost=doc.boost, + hidden=doc.hidden, + ) + + # update Vespa. OK if doc doesn't exist. Raises exception otherwise. 
+ document_index.update_single(update_request=update_request) + + # there are still other cc_pair references to the doc, so just resync to Vespa + delete_document_by_connector_credential_pair__no_commit( + db_session=db_session, + document_id=document_id, + connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( + connector_id=connector_id, + credential_id=credential_id, + ), + ) + + mark_document_as_synced(document_id, db_session) + else: + pass + + # update_docs_last_modified__no_commit( + # db_session=db_session, + # document_ids=[document_id], + # ) + + db_session.commit() + except SoftTimeLimitExceeded: + task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") + except Exception as e: + task_logger.exception("Unexpected exception") + + # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 + countdown = 2 ** (self.request.retries + 4) + self.retry(exc=e, countdown=countdown) + + return True diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index d11d317d0b1..feef67ca4da 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -5,20 +5,22 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from celery.utils.log import get_task_logger from redis import Redis from sqlalchemy.orm import Session from danswer.access.access import get_access_for_document from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_app import task_logger from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks from danswer.db.connector import fetch_connector_by_id +from danswer.db.connector import mark_ccpair_as_pruned from danswer.db.connector_credential_pair import add_deletion_failure_message from danswer.db.connector_credential_pair import ( delete_connector_credential_pair__no_commit, @@ -50,9 +52,6 @@ redis_pool = RedisPool() -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) - # celery auto associates tasks created inside another task, # which bloats the result metadata considerably. trail=False prevents this. 
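For the retry backoff used in document_by_cc_pair_cleanup_task above, a quick sanity check of the 2 ** (retries + 4) formula, given max_retries=3 on the task:

for retries in range(3):
    # 0 -> 16s, 1 -> 32s, 2 -> 64s, matching the "16, 32, 64" comment
    print(f"retry {retries}: countdown={2 ** (retries + 4)}s")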
@@ -280,7 +279,7 @@ def monitor_document_set_taskset( fence_key = key_bytes.decode("utf-8") document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key) if document_set_id is None: - task_logger.warning("could not parse document set id from {key}") + task_logger.warning(f"could not parse document set id from {fence_key}") return rds = RedisDocumentSet(document_set_id) @@ -327,7 +326,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key) if cc_pair_id is None: - task_logger.warning("could not parse document set id from {key}") + task_logger.warning(f"could not parse cc_pair_id from {fence_key}") return rcd = RedisConnectorDeletion(cc_pair_id) @@ -417,6 +416,51 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None: r.delete(rcd.fence_key) +def monitor_ccpair_pruning_taskset( + key_bytes: bytes, r: Redis, db_session: Session +) -> None: + fence_key = key_bytes.decode("utf-8") + cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key) + if cc_pair_id is None: + task_logger.warning( + f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}" + ) + return + + rcp = RedisConnectorPruning(cc_pair_id) + + fence_value = r.get(rcp.fence_key) + if fence_value is None: + return + + generator_value = r.get(rcp.generator_complete_key) + if generator_value is None: + return + + try: + initial_count = int(cast(int, generator_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rcp.taskset_key)) + task_logger.info( + f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + mark_ccpair_as_pruned(cc_pair_id, db_session) + task_logger.info( + f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}" + ) + + r.delete(rcp.taskset_key) + r.delete(rcp.generator_progress_key) + r.delete(rcp.generator_complete_key) + r.delete(rcp.fence_key) + + @shared_task(name="monitor_vespa_sync", soft_time_limit=300) def monitor_vespa_sync() -> None: """This is a celery beat task that monitors and finalizes metadata sync tasksets. 
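monitor_ccpair_pruning_taskset above leans on get_id_from_fence_key; a minimal sketch of how an id can be pulled out of a key such as "connectorpruning_fence_17" (the real helper lives on RedisObjectHelper and may differ in detail):

def get_id_from_fence_key(fence_key: str) -> int | None:
    # expect "<prefix>_fence_<id>"; anything else yields None
    _, _, raw_id = fence_key.rpartition("_")
    try:
        return int(raw_id)
    except ValueError:
        return None

assert get_id_from_fence_key("connectorpruning_fence_17") == 17
assert get_id_from_fence_key("connectorpruning_fence_") is None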
@@ -458,6 +502,9 @@ def monitor_vespa_sync() -> None: ) monitor_usergroup_taskset(key_bytes, r, db_session) + for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): + monitor_ccpair_pruning_taskset(key_bytes, r, db_session) + # uncomment for debugging if needed # r_celery = celery_app.broker_connection().channel().client # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index 983a3c129ba..157681ca556 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -10,202 +10,76 @@ connector / credential pair from the access list (6) delete all relevant entries from postgres """ -from celery import shared_task -from celery import Task -from celery.exceptions import SoftTimeLimitExceeded -from celery.utils.log import get_task_logger -from sqlalchemy.orm import Session - -from danswer.access.access import get_access_for_document -from danswer.access.access import get_access_for_documents -from danswer.db.document import delete_document_by_connector_credential_pair__no_commit -from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit -from danswer.db.document import delete_documents_complete__no_commit -from danswer.db.document import get_document -from danswer.db.document import get_document_connector_count -from danswer.db.document import get_document_connector_counts -from danswer.db.document import mark_document_as_synced -from danswer.db.document import prepare_to_modify_documents -from danswer.db.document_set import fetch_document_sets_for_document -from danswer.db.document_set import fetch_document_sets_for_documents -from danswer.db.engine import get_sqlalchemy_engine -from danswer.document_index.document_index_utils import get_both_index_names -from danswer.document_index.factory import get_default_document_index -from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.interfaces import UpdateRequest -from danswer.server.documents.models import ConnectorCredentialPairIdentifier -from danswer.utils.logger import setup_logger - -logger = setup_logger() - -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) - -_DELETION_BATCH_SIZE = 1000 - - -def delete_connector_credential_pair_batch( - document_ids: list[str], - connector_id: int, - credential_id: int, - document_index: DocumentIndex, -) -> None: - """ - Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore - it gets permanently deleted. 
- """ - with Session(get_sqlalchemy_engine()) as db_session: - # acquire lock for all documents in this batch so that indexing can't - # override the deletion - with prepare_to_modify_documents( - db_session=db_session, document_ids=document_ids - ): - document_connector_counts = get_document_connector_counts( - db_session=db_session, document_ids=document_ids - ) - - # figure out which docs need to be completely deleted - document_ids_to_delete = [ - document_id - for document_id, cnt in document_connector_counts - if cnt == 1 - ] - logger.debug(f"Deleting documents: {document_ids_to_delete}") - - document_index.delete(doc_ids=document_ids_to_delete) - - delete_documents_complete__no_commit( - db_session=db_session, - document_ids=document_ids_to_delete, - ) - - # figure out which docs need to be updated - document_ids_to_update = [ - document_id for document_id, cnt in document_connector_counts if cnt > 1 - ] - - # maps document id to list of document set names - new_doc_sets_for_documents: dict[str, set[str]] = { - document_id_and_document_set_names_tuple[0]: set( - document_id_and_document_set_names_tuple[1] - ) - for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents( - db_session=db_session, - document_ids=document_ids_to_update, - ) - } - - # determine future ACLs for documents in batch - access_for_documents = get_access_for_documents( - document_ids=document_ids_to_update, - db_session=db_session, - ) - - # update Vespa - logger.debug(f"Updating documents: {document_ids_to_update}") - update_requests = [ - UpdateRequest( - document_ids=[document_id], - access=access, - document_sets=new_doc_sets_for_documents[document_id], - ) - for document_id, access in access_for_documents.items() - ] - document_index.update(update_requests=update_requests) - - # clean up Postgres - delete_documents_by_connector_credential_pair__no_commit( - db_session=db_session, - document_ids=document_ids_to_update, - connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( - connector_id=connector_id, - credential_id=credential_id, - ), - ) - db_session.commit() - - -@shared_task( - name="document_by_cc_pair_cleanup_task", - bind=True, - soft_time_limit=45, - time_limit=60, - max_retries=3, -) -def document_by_cc_pair_cleanup_task( - self: Task, document_id: str, connector_id: int, credential_id: int -) -> bool: - task_logger.info(f"document_id={document_id}") - - try: - with Session(get_sqlalchemy_engine()) as db_session: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - - count = get_document_connector_count(db_session, document_id) - if count == 1: - # count == 1 means this is the only remaining cc_pair reference to the doc - # delete it from vespa and the db - document_index.delete(doc_ids=[document_id]) - delete_documents_complete__no_commit( - db_session=db_session, - document_ids=[document_id], - ) - elif count > 1: - # count > 1 means the document still has cc_pair references - doc = get_document(document_id, db_session) - if not doc: - return False - - # the below functions do not include cc_pairs being deleted. - # i.e. 
they will correctly omit access for the current cc_pair - doc_access = get_access_for_document( - document_id=document_id, db_session=db_session - ) - - doc_sets = fetch_document_sets_for_document(document_id, db_session) - update_doc_sets: set[str] = set(doc_sets) - - update_request = UpdateRequest( - document_ids=[document_id], - document_sets=update_doc_sets, - access=doc_access, - boost=doc.boost, - hidden=doc.hidden, - ) - - # update Vespa. OK if doc doesn't exist. Raises exception otherwise. - document_index.update_single(update_request=update_request) - - # there are still other cc_pair references to the doc, so just resync to Vespa - delete_document_by_connector_credential_pair__no_commit( - db_session=db_session, - document_id=document_id, - connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( - connector_id=connector_id, - credential_id=credential_id, - ), - ) - - mark_document_as_synced(document_id, db_session) - else: - pass - - # update_docs_last_modified__no_commit( - # db_session=db_session, - # document_ids=[document_id], - # ) - - db_session.commit() - except SoftTimeLimitExceeded: - task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") - except Exception as e: - task_logger.exception("Unexpected exception") - - # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 - countdown = 2 ** (self.request.retries + 4) - self.retry(exc=e, countdown=countdown) - - return True +# logger = setup_logger() +# _DELETION_BATCH_SIZE = 1000 +# def delete_connector_credential_pair_batch( +# document_ids: list[str], +# connector_id: int, +# credential_id: int, +# document_index: DocumentIndex, +# ) -> None: +# """ +# Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore +# it gets permanently deleted. 
+# """ +# with Session(get_sqlalchemy_engine()) as db_session: +# # acquire lock for all documents in this batch so that indexing can't +# # override the deletion +# with prepare_to_modify_documents( +# db_session=db_session, document_ids=document_ids +# ): +# document_connector_counts = get_document_connector_counts( +# db_session=db_session, document_ids=document_ids +# ) +# # figure out which docs need to be completely deleted +# document_ids_to_delete = [ +# document_id +# for document_id, cnt in document_connector_counts +# if cnt == 1 +# ] +# logger.debug(f"Deleting documents: {document_ids_to_delete}") +# document_index.delete(doc_ids=document_ids_to_delete) +# delete_documents_complete__no_commit( +# db_session=db_session, +# document_ids=document_ids_to_delete, +# ) +# # figure out which docs need to be updated +# document_ids_to_update = [ +# document_id for document_id, cnt in document_connector_counts if cnt > 1 +# ] +# # maps document id to list of document set names +# new_doc_sets_for_documents: dict[str, set[str]] = { +# document_id_and_document_set_names_tuple[0]: set( +# document_id_and_document_set_names_tuple[1] +# ) +# for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents( +# db_session=db_session, +# document_ids=document_ids_to_update, +# ) +# } +# # determine future ACLs for documents in batch +# access_for_documents = get_access_for_documents( +# document_ids=document_ids_to_update, +# db_session=db_session, +# ) +# # update Vespa +# logger.debug(f"Updating documents: {document_ids_to_update}") +# update_requests = [ +# UpdateRequest( +# document_ids=[document_id], +# access=access, +# document_sets=new_doc_sets_for_documents[document_id], +# ) +# for document_id, access in access_for_documents.items() +# ] +# document_index.update(update_requests=update_requests) +# # clean up Postgres +# delete_documents_by_connector_credential_pair__no_commit( +# db_session=db_session, +# document_ids=document_ids_to_update, +# connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( +# connector_id=connector_id, +# credential_id=credential_id, +# ), +# ) +# db_session.commit() diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index e34b8b894d5..b66ca61e1d2 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -182,10 +182,9 @@ class PostgresAdvisoryLocks(Enum): class DanswerCeleryQueues: - VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator" - VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator" VESPA_METADATA_SYNC = "vespa_metadata_sync" CONNECTOR_DELETION = "connector_deletion" + CONNECTOR_PRUNING = "connector_pruning" class DanswerRedisLocks: @@ -193,7 +192,8 @@ class DanswerRedisLocks: CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat" MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat" CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat" - MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat" + CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat" + # MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat" class DanswerCeleryPriority(int, Enum): diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py index 89e6977103e..0f777d30ec9 100644 --- a/backend/danswer/db/connector.py +++ b/backend/danswer/db/connector.py @@ -1,3 +1,5 @@ +from datetime import datetime +from datetime import timezone from typing import cast 
from sqlalchemy import and_ @@ -268,3 +270,15 @@ def create_initial_default_connector(db_session: Session) -> None: ) db_session.add(connector) db_session.commit() + + +def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None: + stmt = select(ConnectorCredentialPair).where( + ConnectorCredentialPair.id == cc_pair_id + ) + cc_pair = db_session.scalar(stmt) + if cc_pair is None: + raise ValueError(f"No cc_pair with ID: {cc_pair_id}") + + cc_pair.last_pruned = datetime.now(timezone.utc) + db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index fff6b12336d..25d2d371778 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -414,6 +414,15 @@ class ConnectorCredentialPair(Base): last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) + + # last successful prune + last_pruned: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, index=True + ) + + # # flag to trigger pruning + # needs_pruning: Mapped[bool] = mapped_column(Boolean, default=False) + total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) connector: Mapped["Connector"] = relationship( diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py index 45ab39d8a02..31be4f21ce2 100644 --- a/backend/danswer/redis/redis_pool.py +++ b/backend/danswer/redis/redis_pool.py @@ -13,7 +13,7 @@ from danswer.configs.app_configs import REDIS_SSL_CA_CERTS from danswer.configs.app_configs import REDIS_SSL_CERT_REQS -REDIS_POOL_MAX_CONNECTIONS = 10 +REDIS_POOL_MAX_CONNECTIONS = 16 class RedisPool: diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 428666751a4..e5f848a4ec1 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -10,9 +10,11 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user +from danswer.background.celery.celery_utils import cc_pair_is_pruning from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot -from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task -from danswer.background.task_utils import name_cc_prune_task +from danswer.background.celery.tasks.pruning.tasks import ( + try_creating_prune_generator_task, +) from danswer.db.connector_credential_pair import add_credential_to_connector from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import remove_credential_from_connector @@ -29,9 +31,8 @@ from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from danswer.db.models import User -from danswer.db.tasks import get_latest_task +from danswer.redis.redis_pool import RedisPool from danswer.server.documents.models import CCPairFullInfo -from danswer.server.documents.models import CCPairPruningTask from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata @@ -43,6 +44,8 @@ logger = setup_logger() router = APIRouter(prefix="/manage") +redis_pool = RedisPool() + @router.get("/admin/cc-pair/{cc_pair_id}/index-attempts") def get_cc_pair_index_attempts( @@ -199,7 +202,7 @@ 
def get_cc_pair_latest_prune( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> CCPairPruningTask: +) -> bool: cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -212,24 +215,7 @@ def get_cc_pair_latest_prune( detail="Connection not found for current user's permissions", ) - # look up the last prune task for this connector (if it exists) - pruning_task_name = name_cc_prune_task( - connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id - ) - last_pruning_task = get_latest_task(pruning_task_name, db_session) - if not last_pruning_task: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="No pruning task found.", - ) - - return CCPairPruningTask( - id=last_pruning_task.task_id, - name=last_pruning_task.task_name, - status=last_pruning_task.status, - start_time=last_pruning_task.start_time, - register_time=last_pruning_task.register_time, - ) + return cc_pair_is_pruning(cc_pair.id, db_session) @router.post("/admin/cc-pair/{cc_pair_id}/prune") @@ -238,8 +224,7 @@ def prune_cc_pair( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[list[int]]: - # avoiding circular refs - from danswer.background.celery.tasks.pruning.tasks import prune_documents_task + """Triggers pruning on a particular cc_pair immediately""" cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, @@ -253,26 +238,26 @@ def prune_cc_pair( detail="Connection not found for current user's permissions", ) - pruning_task_name = name_cc_prune_task( - connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id - ) - last_pruning_task = get_latest_task(pruning_task_name, db_session) - if skip_cc_pair_pruning_by_task( - last_pruning_task, - db_session=db_session, - ): + if cc_pair_is_pruning(cc_pair.id, db_session=db_session): raise HTTPException( status_code=HTTPStatus.CONFLICT, detail="Pruning task already in progress.", ) - logger.info(f"Pruning the {cc_pair.connector.name} connector.") - prune_documents_task.apply_async( - kwargs=dict( - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - ) + logger.info( + f"Pruning cc_pair: cc_pair_id={cc_pair_id} " + f"connector_id={cc_pair.connector_id} " + f"credential_id={cc_pair.credential_id} " + f"{cc_pair.connector.name} connector." 
) + tasks_created = try_creating_prune_generator_task( + cc_pair, db_session, redis_pool.get_client() + ) + if not tasks_created: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Pruning task creation failed.", + ) return StatusResponse( success=True, diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ee266eca8b8..c49fe8dfd18 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -268,12 +268,14 @@ def from_models( ) -class CCPairPruningTask(BaseModel): - id: str - name: str - status: TaskStatus - start_time: datetime | None - register_time: datetime | None +# Temporarily unused, but can be put back in +# once we store more pruning metadata in redis +# class CCPairPruningTask(BaseModel): +# id: str +# name: str +# status: TaskStatus +# start_time: datetime | None +# register_time: datetime | None class FailedConnectorIndexingStatus(BaseModel): diff --git a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py index d194b2ef9a9..faf2ceffb64 100644 --- a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py @@ -15,10 +15,10 @@ def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None: """This function is likely to move in the worker refactor happening next.""" - key = key_bytes.decode("utf-8") - usergroup_id = RedisUserGroup.get_id_from_fence_key(key) + fence_key = key_bytes.decode("utf-8") + usergroup_id = RedisUserGroup.get_id_from_fence_key(fence_key) if not usergroup_id: - task_logger.warning("Could not parse usergroup id from {key}") + task_logger.warning(f"Could not parse usergroup id from {fence_key}") return rug = RedisUserGroup(usergroup_id) diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index 000bbac59d0..a17b5d7da44 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -8,8 +8,6 @@ from danswer.connectors.models import InputType from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.enums import TaskStatus -from danswer.server.documents.models import CCPairPruningTask from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import DocumentSource @@ -247,10 +245,10 @@ def prune( result.raise_for_status() @staticmethod - def get_prune_task( + def is_pruning( cc_pair: DATestCCPair, user_performing_action: DATestUser | None = None, - ) -> CCPairPruningTask: + ) -> bool: response = requests.get( url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune", headers=user_performing_action.headers @@ -258,7 +256,8 @@ def get_prune_task( else GENERAL_HEADERS, ) response.raise_for_status() - return CCPairPruningTask(**response.json()) + response_bool = response.json() + return response_bool @staticmethod def wait_for_prune( @@ -270,16 +269,9 @@ def wait_for_prune( """after: The task register time must be after this time.""" start = time.monotonic() while True: - task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action) - if not task: - raise ValueError("Prune task not found.") - - if not task.register_time or 
task.register_time < after: - raise ValueError("Prune task register time is too early.") - - if task.status == TaskStatus.SUCCESS: - # Pruning succeeded - return + result = CCPairManager.is_pruning(cc_pair_test, user_performing_action) + if not result: + break elapsed = time.monotonic() - start if elapsed > timeout: diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index 46a65f768a9..fc9ea1e0fad 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -11,6 +11,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.enums import IndexingStatus +from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import create_index_attempt_error from danswer.db.models import IndexAttempt from danswer.db.search_settings import get_current_search_settings @@ -117,6 +118,22 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None: user_performing_action=admin_user, ) + # inject an index attempt and index attempt error (exercises foreign key errors) + with Session(get_sqlalchemy_engine()) as db_session: + attempt_id = create_index_attempt( + connector_credential_pair_id=cc_pair_1.id, + search_settings_id=1, + db_session=db_session, + ) + create_index_attempt_error( + index_attempt_id=attempt_id, + batch=1, + docs=[], + exception_msg="", + exception_traceback="", + db_session=db_session, + ) + # Update local records to match the database for later comparison user_group_1.cc_pair_ids = [] user_group_2.cc_pair_ids = [cc_pair_2.id]
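wait_for_prune above is an instance of a generic poll-until-clear loop; a hedged standalone version follows, with illustrative names and defaults.

import time
from collections.abc import Callable

def wait_until_clear(
    still_running: Callable[[], bool], timeout: float = 300.0, interval: float = 5.0
) -> None:
    # poll until the condition clears, or fail loudly after the timeout
    start = time.monotonic()
    while still_running():
        if time.monotonic() - start > timeout:
            raise TimeoutError("condition did not clear before the timeout")
        time.sleep(interval)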