diff --git a/run_local.sh b/run_local.sh index 5fe2f2e56..f921bba08 100755 --- a/run_local.sh +++ b/run_local.sh @@ -17,6 +17,8 @@ export DIRACX_CONFIG_BACKEND_URL="git+file://${tmp_dir}/cs_store/initialRepo" export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:" export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${tmp_dir}/signing-key/rs256.key" export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]' diff --git a/setup.cfg b/setup.cfg index f70bd5d1d..45e432ba2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,6 +78,7 @@ diracx.dbs = JobDB = diracx.db:JobDB JobLoggingDB = diracx.db:JobLoggingDB SandboxMetadataDB = diracx.db:SandboxMetadataDB + TaskQueueDB = diracx.db:TaskQueueDB #DummyDB = diracx.db:DummyDB diracx.services = jobs = diracx.routers.job_manager:router diff --git a/src/diracx/cli/internal.py b/src/diracx/cli/internal.py index 75b9ab0c8..855a3b7dd 100644 --- a/src/diracx/cli/internal.py +++ b/src/diracx/cli/internal.py @@ -48,9 +48,7 @@ def generate_cs( DefaultGroup=user_group, Users={}, Groups={ - user_group: GroupConfig( - JobShare=None, Properties=["NormalUser"], Quota=None, Users=[] - ) + user_group: GroupConfig(Properties=["NormalUser"], Quota=None, Users=[]) }, ) config = Config( diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 8f3470e05..1ee61f6d5 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -49,7 +49,7 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int = 1000 Properties: list[SecurityProperty] Quota: Optional[int] Users: list[str] @@ -86,9 +86,14 @@ class JobMonitoringConfig(BaseModel): useESForJobParametersFlag: bool = False +class JobSchedulingConfig(BaseModel): + EnableSharesCorrection: bool = False + + class ServicesConfig(BaseModel): Catalogs: dict[str, Any] | None JobMonitoring: JobMonitoringConfig = JobMonitoringConfig() + JobScheduling: JobSchedulingConfig = JobSchedulingConfig() class OperationsConfig(BaseModel): diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index fa6048011..81753f34e 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -31,7 +31,7 @@ class SortSpec(TypedDict): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str + value: str | int class VectorSearchSpec(TypedDict): diff --git a/src/diracx/db/__init__.py b/src/diracx/db/__init__.py index 3dd13c3c9..e80409d69 100644 --- a/src/diracx/db/__init__.py +++ b/src/diracx/db/__init__.py @@ -1,7 +1,7 @@ -__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB") +__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") from .auth.db import AuthDB -from .jobs.db import JobDB, JobLoggingDB +from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB from .sandbox_metadata.db import SandboxMetadataDB # from .dummy.db import DummyDB diff --git a/src/diracx/db/jobs/db.py b/src/diracx/db/jobs/db.py index 573bbcc3c..0b0a581d7 100644 --- a/src/diracx/db/jobs/db.py +++ b/src/diracx/db/jobs/db.py @@ -9,16 +9,27 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobStatusReturn, LimitedJobStatusReturn +from diracx.core.properties import JOB_SHARING, SecurityProperty from diracx.core.utils import JobStatus from ..utils import BaseDB, apply_search_filters from .schema import ( + BannedSitesQueue, + GridCEsQueue, InputData, + JobCommands, JobDBBase, JobJDLs, JobLoggingDBBase, Jobs, + JobsQueue, + JobTypesQueue, LoggingInfo, + PlatformsQueue, + SitesQueue, + TagsQueue, + TaskQueueDBBase, + TaskQueues, ) @@ -260,6 +271,37 @@ async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: **dict((await self.conn.execute(stmt)).one()._mapping) ) + async def set_job_command(self, job_id: int, command: str, arguments: str = ""): + """Store a command to be passed to the job together with the next heart beat""" + stmt = insert(JobCommands).values( + JobID=job_id, + Command=command, + Arguments=arguments, + ReceptionTime=datetime.now(tz=timezone.utc), + ) + await self.conn.execute(stmt) + + async def get_vo(self, job_id: int) -> str: + """ + Get the VO of the owner of the job + """ + # TODO: Consider having a VO column in the Jobs table + from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd + + # TODO: this is going to be problematic + stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id) + jdl = (await self.conn.execute(stmt)).scalar_one() + if not jdl.startswith("["): + jdl = f"[{jdl}]" + return ClassAd(jdl).getAttributeString("VirtualOrganisation") + + async def delete_job(self, job_id: int): + """ + Delete a job from the database + """ + stmt = delete(JobJDLs).where(JobJDLs.JobID == job_id) + await self.conn.execute(stmt) + MAGIC_EPOC_NUMBER = 1270000000 @@ -405,3 +447,298 @@ async def get_wms_time_stamps(self, job_id): result[event] = str(etime + MAGIC_EPOC_NUMBER) return result + + +class TaskQueueDB(BaseDB): + metadata = TaskQueueDBBase.metadata + + async def get_tq_id_for_job(self, job_id: int) -> int: + """ + Get the task queue info for a given job + """ + stmt = select(TaskQueues.TQId).where(JobsQueue.JobId == job_id) + return (await self.conn.execute(stmt)).scalar_one() + + async def get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]: + """ + Get the owner and owner group for a task queue + """ + stmt = select(TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO).where( + TaskQueues.TQId == tq_id + ) + return dict((await self.conn.execute(stmt)).one()._mapping) + + async def remove_job(self, job_id: int): + """ + Remove a job from the task queues + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id) + await self.conn.execute(stmt) + + async def delete_task_queue_if_empty( + self, + tq_id: int, + tq_owner: str, + tq_group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Try to delete a task queue if it's empty + """ + # Check if the task queue is empty + stmt = ( + select(TaskQueues.TQId) + .where(TaskQueues.Enabled >= 1) + .where(TaskQueues.TQId == tq_id) + .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) + ) + rows = await self.conn.execute(stmt) + if not rows.rowcount: + return + + # Deleting the task queue (the other tables will be deleted in cascade) + stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) + await self.conn.execute(stmt) + + await self.recalculate_tq_shares_for_entity( + tq_owner, + tq_group, + job_share, + group_properties, + enable_shares_correction, + allow_background_tqs, + ) + + async def recalculate_tq_shares_for_entity( + self, + owner: str, + group: str, + job_share: int, + group_properties: list[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Recalculate the shares for a user/userGroup combo + """ + if JOB_SHARING in group_properties: + # If group has JobSharing just set prio for that entry, user is irrelevant + return await self.__set_priorities_for_entity( + owner, group, job_share, group_properties, allow_background_tqs + ) + + stmt = ( + select(TaskQueues.Owner, func.count(TaskQueues.Owner)) + .where(TaskQueues.OwnerGroup == group) + .group_by(TaskQueues.Owner) + ) + rows = await self.conn.execute(stmt) + # make the rows a list of tuples + # Get owners in this group and the amount of times they appear + # TODO: I guess the rows are already a list of tupes + # maybe refactor + data = [(r[0], r[1]) for r in rows if r] + numOwners = len(data) + # If there are no owners do now + if numOwners == 0: + return + # Split the share amongst the number of owners + entities_shares = {row[0]: job_share / numOwners for row in data} + + # TODO: implement the following + # If corrector is enabled let it work it's magic + # if enable_shares_correction: + # entities_shares = await self.__shares_corrector.correct_shares( + # entitiesShares, group=group + # ) + + # Keep updating + owners = dict(data) + # IF the user is already known and has more than 1 tq, the rest of the users don't need to be modified + # (The number of owners didn't change) + if owner in owners and owners[owner] > 1: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + return + # Oops the number of owners may have changed so we recalculate the prio for all owners in the group + for owner in owners: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + pass + + async def __set_priorities_for_entity( + self, + owner: str, + group: str, + share, + properties: list[SecurityProperty], + allow_background_tqs: bool, + ): + """ + Set the priority for a user/userGroup combo given a splitted share + """ + from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import ( + TQ_MIN_SHARE, + priorityIgnoredFields, + ) + + stmt = ( + select( + TaskQueues.TQId, + func.sum(JobsQueue.RealPriority) / func.count(JobsQueue.RealPriority), + ) + # TODO: uncomment me and understand why mypy is unhappy with join here and not elsewhere + # .select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .where(TaskQueues.OwnerGroup == group).group_by(TaskQueues.TQId) + ) + if JOB_SHARING not in properties: + stmt = stmt.where(TaskQueues.Owner == owner) + rows = await self.conn.execute(stmt) + # Make a dict of TQId:priority + tqDict: dict[int, float] = {row[0]: row[1] for row in rows} + + if not tqDict: + return + + allowBgTQs = allow_background_tqs + + # TODO: one of the only place the logic could actually be encapsulated + # so refactor + + # Calculate Sum of priorities + totalPrio = 0.0 + for k in tqDict: + if tqDict[k] > 0.1 or not allowBgTQs: + totalPrio += tqDict[k] + # Update prio for each TQ + for tqId in tqDict: + if tqDict[tqId] > 0.1 or not allowBgTQs: + prio = (share / totalPrio) * tqDict[tqId] + else: + prio = TQ_MIN_SHARE + prio = max(prio, TQ_MIN_SHARE) + tqDict[tqId] = prio + + # Generate groups of TQs that will have the same prio=sum(prios) maomenos + rows = await self.retrieve_task_queues(list(tqDict)) + # TODO: check the following asumption is correct + allTQsData = rows + tqGroups: dict[str, list] = {} + for tqid in allTQsData: + tqData = allTQsData[tqid] + for field in ("Jobs", "Priority") + priorityIgnoredFields: + if field in tqData: + tqData.pop(field) + tqHash = [] + for f in sorted(tqData): + tqHash.append(f"{f}:{tqData[f]}") + tqHash = "|".join(tqHash) + if tqHash not in tqGroups: + tqGroups[tqHash] = [] + tqGroups[tqHash].append(tqid) + groups = [tqGroups[td] for td in tqGroups] + + # Do the grouping + for tqGroup in groups: + totalPrio = 0 + if len(tqGroup) < 2: + continue + for tqid in tqGroup: + totalPrio += tqDict[tqid] + for tqid in tqGroup: + tqDict[tqid] = totalPrio + + # Group by priorities + prioDict: dict[int, list] = {} + for tqId in tqDict: + prio = tqDict[tqId] + if prio not in prioDict: + prioDict[prio] = [] + prioDict[prio].append(tqId) + + # Execute updates + for prio, tqs in prioDict.items(): + update_stmt = ( + update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio) + ) + await self.conn.execute(update_stmt) + + async def retrieve_task_queues(self, tqIdList=None): + """ + Get all the task queues + """ + if tqIdList is not None and not tqIdList: + # Empty list => Fast-track no matches + return {} + + stmt = ( + select( + TaskQueues.TQId, + TaskQueues.Priority, + func.count(JobsQueue.TQId).label("Jobs"), + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + .select_from(TaskQueues.join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)) + .select_from( + TaskQueues.join(SitesQueue, TaskQueues.TQId == SitesQueue.TQId) + ) + .select_from( + TaskQueues.join(GridCEsQueue, TaskQueues.TQId == GridCEsQueue.TQId) + ) + .group_by( + TaskQueues.TQId, + TaskQueues.Priority, + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + ) + if tqIdList is not None: + stmt = stmt.where(TaskQueues.TQId.in_(tqIdList)) + + tqData = dict(row._mapping for row in await self.conn.execute(stmt)) + # TODO: the line above should be equivalent to the following commented code, check this is the case + # for record in rows: + # tqId = record[0] + # tqData[tqId] = { + # "Priority": record[1], + # "Jobs": record[2], + # "Owner": record[3], + # "OwnerGroup": record[4], + # "VO": record[5], + # "CPUTime": record[6], + # } + + for tqId in tqData: + # TODO: maybe factorize this handy tuple list + for table, field in { + (SitesQueue, "Sites"), + (GridCEsQueue, "GridCEs"), + (BannedSitesQueue, "BannedSites"), + (PlatformsQueue, "Platforms"), + (JobTypesQueue, "JobTypes"), + (TagsQueue, "Tags"), + }: + stmt = select(table.Value).where(table.TQId == tqId) + tqData[tqId][field] = list( + row[0] for row in await self.conn.execute(stmt) + ) + + return tqData diff --git a/src/diracx/db/jobs/schema.py b/src/diracx/db/jobs/schema.py index fe76eedd2..e1b6625bc 100644 --- a/src/diracx/db/jobs/schema.py +++ b/src/diracx/db/jobs/schema.py @@ -1,9 +1,11 @@ import sqlalchemy.types as types from sqlalchemy import ( + BigInteger, + Boolean, DateTime, Enum, + Float, ForeignKey, - ForeignKeyConstraint, Index, Integer, Numeric, @@ -17,6 +19,7 @@ JobDBBase = declarative_base() JobLoggingDBBase = declarative_base() +TaskQueueDBBase = declarative_base() class EnumBackedBool(types.TypeDecorator): @@ -45,19 +48,16 @@ def process_result_value(self, value, dialect) -> bool: raise NotImplementedError(f"Unknown {value=}") -class JobJDLs(JobDBBase): - __tablename__ = "JobJDLs" - JobID = Column(Integer, autoincrement=True) - JDL = Column(Text) - JobRequirements = Column(Text) - OriginalJDL = Column(Text) - __table_args__ = (PrimaryKeyConstraint("JobID"),) - - class Jobs(JobDBBase): __tablename__ = "Jobs" - JobID = Column("JobID", Integer, primary_key=True, default=0) + JobID = Column( + "JobID", + Integer, + ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), + primary_key=True, + default=0, + ) JobType = Column("JobType", String(32), default="user") DIRACSetup = Column("DIRACSetup", String(32), default="test") JobGroup = Column("JobGroup", String(32), default="00000000") @@ -96,7 +96,6 @@ class Jobs(JobDBBase): ) __table_args__ = ( - ForeignKeyConstraint(["JobID"], ["JobJDLs.JobID"]), Index("JobType", "JobType"), Index("JobGroup", "JobGroup"), Index("JobSplitType", "JobSplitType"), @@ -111,33 +110,46 @@ class Jobs(JobDBBase): ) +class JobJDLs(JobDBBase): + __tablename__ = "JobJDLs" + JobID = Column(Integer, autoincrement=True, primary_key=True) + JDL = Column(Text) + JobRequirements = Column(Text) + OriginalJDL = Column(Text) + + class InputData(JobDBBase): __tablename__ = "InputData" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) LFN = Column(String(255), default="", primary_key=True) Status = Column(String(32), default="AprioriGood") - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class JobParameters(JobDBBase): __tablename__ = "JobParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - JobID = Column(Integer, ForeignKey("Jobs.JobID"), primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) RescheduleCycle = Column(Integer) @@ -163,25 +175,25 @@ class SiteMaskLogging(JobDBBase): class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) HeartBeatTime = Column(DateTime, primary_key=True) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class JobCommands(JobDBBase): __tablename__ = "JobCommands" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Command = Column(String(100)) Arguments = Column(String(100)) Status = Column(String(64), default="Received") ReceptionTime = Column(DateTime, primary_key=True) ExecutionTime = NullColumn(DateTime) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" @@ -195,3 +207,99 @@ class LoggingInfo(JobLoggingDBBase): StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) StatusSource = Column(String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) + + +class TaskQueues(TaskQueueDBBase): + __tablename__ = "tq_TaskQueues" + TQId = Column(Integer, primary_key=True) + Owner = Column(String(255), nullable=False) + OwnerDN = Column(String(255)) + OwnerGroup = Column(String(32), nullable=False) + VO = Column(String(32), nullable=False) + CPUTime = Column(BigInteger, nullable=False) + Priority = Column(Float, nullable=False) + Enabled = Column(Boolean, nullable=False, default=0) + __table_args__ = (Index("TQOwner", "Owner", "OwnerGroup", "CPUTime"),) + + +class JobsQueue(TaskQueueDBBase): + __tablename__ = "tq_Jobs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + JobId = Column(Integer, primary_key=True) + Priority = Column(Integer, nullable=False) + RealPriority = Column(Float, nullable=False) + __table_args__ = (Index("TaskIndex", "TQId"),) + + +class SitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("SitesTaskIndex", "TQId"), + Index("SitesIndex", "Value"), + ) + + +class GridCEsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToGridCEs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("GridCEsTaskIndex", "TQId"), + Index("GridCEsValueIndex", "Value"), + ) + + +class BannedSitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToBannedSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("BannedSitesTaskIndex", "TQId"), + Index("BannedSitesValueIndex", "Value"), + ) + + +class PlatformsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToPlatforms" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("PlatformsTaskIndex", "TQId"), + Index("PlatformsValueIndex", "Value"), + ) + + +class JobTypesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToJobTypes" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("JobTypesTaskIndex", "TQId"), + Index("JobTypesValueIndex", "Value"), + ) + + +class TagsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToTags" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("TagsTaskIndex", "TQId"), + Index("TagsValueIndex", "Value"), + ) diff --git a/src/diracx/db/jobs/status_utility.py b/src/diracx/db/jobs/status_utility.py index c83c05e04..4116981d6 100644 --- a/src/diracx/db/jobs/status_utility.py +++ b/src/diracx/db/jobs/status_utility.py @@ -1,15 +1,19 @@ from datetime import datetime, timezone from unittest.mock import MagicMock +from fastapi import BackgroundTasks from sqlalchemy.exc import NoResultFound +from diracx.core.config.schema import Config from diracx.core.models import ( JobStatusUpdate, ScalarSearchOperator, + ScalarSearchSpec, SetJobStatusReturn, ) from diracx.core.utils import JobStatus -from diracx.db.jobs.db import JobDB, JobLoggingDB +from diracx.db.jobs.db import JobDB, JobLoggingDB, TaskQueueDB +from diracx.routers.dependencies import SandboxMetadataDB async def set_job_status( @@ -146,3 +150,195 @@ async def set_job_status( ) return SetJobStatusReturn(**job_data) + + +async def delete_job( + job_id: int, + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Delete a job by killing and setting the job status to DELETED. + """ + + await __kill_delete_job( + job_id, + config, + job_db, + task_queue_db, + background_task, + ) + + await set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.DELETED, + MinorStatus="Checking accounting", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + + +async def kill_job( + job_id: int, + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + await __kill_delete_job( + job_id, + config, + job_db, + task_queue_db, + background_task, + ) + + await set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.KILLED, + MinorStatus="Marked for termination", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + + +async def __kill_delete_job( + job_id: int, + config: Config, + job_db: JobDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + from DIRAC.StorageManagementSystem.Client.StorageManagerClient import ( + StorageManagerClient, + ) + + res = await job_db.search( + parameters=["Status", "Owner", "OwnerGroup"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ), + ], + sorts=[], + ) + if not res: + raise NoResultFound(f"Job {job_id} not found") + + status = res[0]["Status"] + owner = res[0]["Owner"] + owner_group = res[0]["OwnerGroup"] + vo = await job_db.get_vo(job_id) + + if status == JobStatus.STAGING: + returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + if status in (JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED): + await job_db.set_job_command(job_id, "Kill") + + # Delete the job from the task queue + try: + tq_id = await task_queue_db.get_tq_id_for_job(job_id) + await task_queue_db.remove_job(job_id) + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) + except NoResultFound: + pass + + +async def remove_job( + job_id: int, + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + """ + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + from DIRAC.StorageManagementSystem.Client.StorageManagerClient import ( + StorageManagerClient, + ) + + res = await job_db.search( + parameters=["Status", "Owner", "OwnerGroup"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ), + ], + sorts=[], + ) + if not res: + raise NoResultFound(f"Job {job_id} not found") + + status = res[0]["Status"] + owner = res[0]["Owner"] + owner_group = res[0]["OwnerGroup"] + vo = await job_db.get_vo(job_id) + + # Remove the staging task from the StorageManager + # TODO: this was not done in the JobManagerHandler, but it was done in the kill method + # I think it should be done here too + if status == JobStatus.STAGING: + returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + # Remove the job from SandboxMetadataDB + # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent + # I think it should be done here as well + await sandbox_metadata_db.unassign_sandbox_from_job(job_id) + + # Remove the job from TaskQueueDB + try: + tq_id = await task_queue_db.get_tq_id_for_job(job_id) + await task_queue_db.remove_job(job_id) + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) + except NoResultFound: + pass + + # Remove the job from JobLoggingDB + await job_logging_db.delete_records(job_id) + + # Remove the job from JobDB + await job_db.delete_job(job_id) diff --git a/src/diracx/db/sandbox_metadata/db.py b/src/diracx/db/sandbox_metadata/db.py index 384b9a142..425aa4ec4 100644 --- a/src/diracx/db/sandbox_metadata/db.py +++ b/src/diracx/db/sandbox_metadata/db.py @@ -6,10 +6,11 @@ import datetime import sqlalchemy +from sqlalchemy import delete from ..utils import BaseDB from .schema import Base as SandboxMetadataDBBase -from .schema import sb_Owners, sb_SandBoxes +from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes class SandboxMetadataDB(BaseDB): @@ -36,7 +37,7 @@ async def _get_put_owner(self, owner: str, owner_group: str) -> int: async def insert( self, owner: str, owner_group: str, sb_SE: str, se_PFN: str, size: int = 0 - ) -> tuple[int, bool]: + ) -> int: """inserts a new sandbox in SandboxMetadataDB this is "equivalent" of DIRAC registerAndGetSandbox @@ -78,3 +79,15 @@ async def delete(self, sandbox_ids: list[int]) -> bool: await self.conn.execute(stmt) return True + + async def unassign_sandbox_from_job( + self, + job_id: int, + ): + """ + Unassign sandbox from job + """ + stmt = delete(sb_EntityMapping).where( + sb_EntityMapping.EntityId == f"Job:{job_id}" + ) + await self.conn.execute(stmt) diff --git a/src/diracx/db/utils.py b/src/diracx/db/utils.py index 93258eb4e..01466c978 100644 --- a/src/diracx/db/utils.py +++ b/src/diracx/db/utils.py @@ -126,6 +126,9 @@ async def engine_context(self) -> AsyncIterator[None]: echo=True, ) async with engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if self._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(self.metadata.create_all) self._engine = engine diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 17d4cf830..e42e7fa6c 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -5,6 +5,8 @@ "AuthDB", "JobDB", "JobLoggingDB", + "SandboxMetadataDB", + "TaskQueueDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -19,6 +21,8 @@ from diracx.db import AuthDB as _AuthDB from diracx.db import JobDB as _JobDB from diracx.db import JobLoggingDB as _JobLoggingDB +from diracx.db import SandboxMetadataDB as _SandboxMetadataDB +from diracx.db import TaskQueueDB as _TaskQueueDB T = TypeVar("T") @@ -32,6 +36,10 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +SandboxMetadataDB = Annotated[ + _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) +] +TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index 2f06a4756..f52b904ff 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -6,7 +6,7 @@ from http import HTTPStatus from typing import Annotated, Any, TypedDict -from fastapi import Body, Depends, HTTPException, Query +from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query from pydantic import BaseModel, root_validator from sqlalchemy.exc import NoResultFound @@ -23,11 +23,14 @@ from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.core.utils import JobStatus from diracx.db.jobs.status_utility import ( + delete_job, + kill_job, + remove_job, set_job_status, ) from ..auth import UserInfo, has_properties, verify_dirac_access_token -from ..dependencies import JobDB, JobLoggingDB +from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB from ..fastapi_classes import DiracxRouter MAX_PARAMETRIC_JOBS = 20 @@ -233,14 +236,101 @@ def __init__(self, user_info: UserInfo, allInfo: bool = True): @router.delete("/") -async def delete_bulk_jobs(job_ids: Annotated[list[int], Query()]): +async def delete_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + # TODO: implement job policy + try: + await asyncio.gather( + *( + delete_job( + job_id, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + for job_id in job_ids + ) + ) + except NoResultFound as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e + return job_ids @router.post("/kill") async def kill_bulk_jobs( job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, ): + try: + await asyncio.gather( + *( + kill_job( + job_id, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + for job_id in job_ids + ) + ) + except NoResultFound as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e + + return job_ids + + +@router.post("/remove") +async def remove_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a list of jobs from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead for any other purpose. + """ + # TODO: implement job policy + + try: + await asyncio.gather( + *( + remove_job( + job_id, + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + for job_id in job_ids + ) + ) + except NoResultFound as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e + return job_ids @@ -406,6 +496,98 @@ async def get_single_job(job_id: int): return f"This job {job_id}" +@router.delete("/{job_id}") +async def delete_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Delete a job by killing and setting the job status to DELETED. + """ + + # TODO: implement job policy + try: + await delete_job( + job_id, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except NoResultFound as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND.value, detail=str(e) + ) from e + + return f"Job {job_id} has been successfully deleted" + + +@router.post("/{job_id}/kill") +async def kill_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill a job. + """ + + # TODO: implement job policy + + try: + await kill_job( + job_id, config, job_db, job_logging_db, task_queue_db, background_task + ) + except NoResultFound as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e + + return f"Job {job_id} has been successfully killed" + + +@router.post("/{job_id}/remove") +async def remove_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead. + """ + + # TODO: implement job policy + + try: + await remove_job( + job_id, + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + except NoResultFound as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e + + return f"Job {job_id} has been successfully removed" + + @router.get("/{job_id}/status") async def get_single_job_status( job_id: int, job_db: JobDB diff --git a/tests/conftest.py b/tests/conftest.py index 24eacb6b9..000c03add 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,7 @@ def pytest_collection_modifyitems(config, items): # --regenerate-client given in cli: allow client re-generation return skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run") + found = False for item in items: if item.name == "test_regenerate_client": item.add_marker(skip_regen) @@ -74,6 +75,7 @@ def with_app(test_auth_settings, with_config_repo): database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", "JobLoggingDB": "sqlite+aiosqlite:///:memory:", + "TaskQueueDB": "sqlite+aiosqlite:///:memory:", "AuthDB": "sqlite+aiosqlite:///:memory:", "SandboxMetadataDB": "sqlite+aiosqlite:///:memory:", }, diff --git a/tests/routers/test_job_manager.py b/tests/routers/test_job_manager.py index f42a1f1d9..51660d52a 100644 --- a/tests/routers/test_job_manager.py +++ b/tests/routers/test_job_manager.py @@ -589,3 +589,60 @@ def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient): # Assert assert r.status_code == 404, r.json() assert r.json() == {"detail": "Job 999999999 not found"} + + +def test_delete_job(normal_user_client: TestClient): + # Arrange + job_definitions = [TEST_JDL] + r = normal_user_client.post("/jobs/", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == 1 + job_id = r.json()[0]["JobID"] + + # Act + r = normal_user_client.delete(f"/jobs/{job_id}") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(job_id)]["Status"] == JobStatus.DELETED + assert r.json()[str(job_id)]["MinorStatus"] == "Checking accounting" + assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + + +def test_kill_job(normal_user_client: TestClient): + # Arrange + job_definitions = [TEST_JDL] + r = normal_user_client.post("/jobs/", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == 1 + job_id = r.json()[0]["JobID"] + + # Act + r = normal_user_client.post(f"/jobs/{job_id}/kill") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(job_id)]["Status"] == JobStatus.KILLED + assert r.json()[str(job_id)]["MinorStatus"] == "Marked for termination" + assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + + +def test_remove_job(normal_user_client: TestClient): + # Arrange + job_definitions = [TEST_JDL] + r = normal_user_client.post("/jobs/", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == 1 + job_id = r.json()[0]["JobID"] + + # Act + r = normal_user_client.post(f"/jobs/{job_id}/remove") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/jobs/{job_id}/status") + assert r.status_code == 404, r.json()