diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py
index 4431d2f4..46450045 100644
--- a/diracx-core/src/diracx/core/config/__init__.py
+++ b/diracx-core/src/diracx/core/config/__init__.py
@@ -21,7 +21,7 @@
 from cachetools import Cache, LRUCache, TTLCache, cachedmethod
 from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints
 
-from ..exceptions import BadConfigurationVersion
+from ..exceptions import BadConfigurationVersionError
 from ..extensions import select_from_extension
 from .schema import Config
 
@@ -159,7 +159,7 @@ def latest_revision(self) -> tuple[str, datetime]:
             ).strip()
             modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc)
         except sh.ErrorReturnCode as e:
-            raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e
+            raise BadConfigurationVersionError(f"Error parsing latest revision: {e}") from e
         logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified)
         return rev, modified
 
@@ -176,7 +176,7 @@ def read_raw(self, hexsha: str, modified: datetime) -> Config:
             )
             raw_obj = yaml.safe_load(blob)
         except sh.ErrorReturnCode as e:
-            raise BadConfigurationVersion(f"Error reading configuration: {e}") from e
+            raise BadConfigurationVersionError(f"Error reading configuration: {e}") from e
 
         config_class: Config = select_from_extension(group="diracx", name="config")[
             0
diff --git a/diracx-core/src/diracx/core/config/schema.py b/diracx-core/src/diracx/core/config/schema.py
index 92d623da..8da2837a 100644
--- a/diracx-core/src/diracx/core/config/schema.py
+++ b/diracx-core/src/diracx/core/config/schema.py
@@ -115,7 +115,6 @@ class DIRACConfig(BaseModel):
 
 class JobMonitoringConfig(BaseModel):
     GlobalJobsInfo: bool = True
-    useESForJobParametersFlag: bool = False
 
 
 class JobSchedulingConfig(BaseModel):
diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py
index 3338f3b1..79834b1c 100644
--- a/diracx-core/src/diracx/core/exceptions.py
+++ b/diracx-core/src/diracx/core/exceptions.py
@@ -1,7 +1,7 @@
 from http import HTTPStatus
 
 
-class DiracHttpResponse(RuntimeError):
+class DiracHttpResponseError(RuntimeError):
     def __init__(self, status_code: int, data):
         self.status_code = status_code
         self.data = data
@@ -30,7 +30,7 @@ class ConfigurationError(DiracError):
     """Used whenever we encounter a problem with the configuration."""
 
 
-class BadConfigurationVersion(ConfigurationError):
+class BadConfigurationVersionError(ConfigurationError):
     """The requested version is not known."""
 
 
@@ -38,7 +38,7 @@ class InvalidQueryError(DiracError):
     """It was not possible to build a valid database query from the given input."""
 
 
-class JobNotFound(Exception):
+class JobNotFoundError(Exception):
     def __init__(self, job_id: int, detail: str | None = None):
         self.job_id: int = job_id
         super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else ""))
diff --git a/diracx-db/src/diracx/db/exceptions.py b/diracx-db/src/diracx/db/exceptions.py
index ca0cf0ec..0a163f92 100644
--- a/diracx-db/src/diracx/db/exceptions.py
+++ b/diracx-db/src/diracx/db/exceptions.py
@@ -1,2 +1,2 @@
-class DBUnavailable(Exception):
+class DBUnavailableError(Exception):
     pass
diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py
index 8b611c00..431cceaa 100644
--- a/diracx-db/src/diracx/db/os/utils.py
+++ b/diracx-db/src/diracx/db/os/utils.py
@@ -16,7 +16,7 @@
 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.extensions import select_from_extension
-from diracx.db.exceptions import DBUnavailable
+from diracx.db.exceptions import DBUnavailableError
 
 logger = logging.getLogger(__name__)
 
@@ -25,7 +25,7 @@ class OpenSearchDBError(Exception):
     pass
 
 
-class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
+class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError):
     pass
 
 
@@ -152,7 +152,7 @@ async def ping(self):
         be run at every query.
         """
         if not await self.client.ping():
-            raise OpenSearchDBUnavailable(
+            raise OpenSearchDBUnavailableError(
                 f"Failed to connect to {self.__class__.__qualname__}"
            )
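Every exception renamed above follows pep8-naming's N818 convention (an `Error` suffix). Downstream code that still imports the old names will break at import time. A minimal sketch, assuming a hypothetical transition shim in an extension's own module (not part of this diff), of how the old name could keep working with a deprecation warning via PEP 562's module-level `__getattr__`:

```python
# Hypothetical compatibility shim; the alias and warning are illustrative only.
import warnings

from diracx.core.exceptions import JobNotFoundError


def __getattr__(name):  # PEP 562 module-level __getattr__
    if name == "JobNotFound":
        warnings.warn(
            "JobNotFound was renamed to JobNotFoundError",
            DeprecationWarning,
            stacklevel=2,
        )
        return JobNotFoundError
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```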
diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py
index 9a033163..fa6bd8f1 100644
--- a/diracx-db/src/diracx/db/sql/dummy/db.py
+++ b/diracx-db/src/diracx/db/sql/dummy/db.py
@@ -25,7 +25,7 @@ class DummyDB(BaseSQLDB):
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         columns = [Cars.__table__.columns[x] for x in group_by]
 
-        stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
+        stmt = select(*columns, func.count(Cars.license_plate).label("count"))
         stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)
 
@@ -44,7 +44,7 @@ async def insert_owner(self, name: str) -> int:
 
     async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
         stmt = insert(Cars).values(
-            licensePlate=license_plate, model=model, ownerID=owner_id
+            license_plate=license_plate, model=model, owner_id=owner_id
         )
 
         result = await self.conn.execute(stmt)
diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py
index ebb37b8d..b6ddde79 100644
--- a/diracx-db/src/diracx/db/sql/dummy/schema.py
+++ b/diracx-db/src/diracx/db/sql/dummy/schema.py
@@ -10,13 +10,13 @@
 class Owners(Base):
     __tablename__ = "Owners"
-    ownerID = Column(Integer, primary_key=True, autoincrement=True)
+    owner_id = Column(Integer, primary_key=True, autoincrement=True)
     creation_time = DateNowColumn()
     name = Column(String(255))
 
 
 class Cars(Base):
     __tablename__ = "Cars"
-    licensePlate = Column(Uuid(), primary_key=True)
+    license_plate = Column(Uuid(), primary_key=True)
     model = Column(String(255))
-    ownerID = Column(Integer, ForeignKey(Owners.ownerID))
+    owner_id = Column(Integer, ForeignKey(Owners.owner_id))
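Because SQLAlchemy uses the attribute key as the column name when none is given explicitly, renaming `licensePlate` to `license_plate` also renames the generated column; that is harmless for these test schemas. For a real table that must keep its legacy column name, the Python attribute and the stored name can differ. A minimal sketch (the table and names are assumptions, not part of this diff):

```python
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Owners(Base):
    __tablename__ = "Owners"

    # The first positional argument pins the stored column name to "ownerID",
    # while Python code uses the snake_case attribute owner_id.
    owner_id = Column("ownerID", Integer, primary_key=True, autoincrement=True)
    name = Column(String(255))
```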
| {"LastUpdateTime": datetime.now(tz=timezone.utc)} - stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData) + if "Status" in job_data: + job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)} + stmt = update(Jobs).where(Jobs.JobID == job_id).values(job_data) await self.conn.execute(stmt) async def create_job(self, original_jdl): @@ -159,9 +160,9 @@ async def update_job_jdls(self, jdls_to_update: dict[int, str]): ], ) - async def checkAndPrepareJob( + async def check_and_prepare_job( self, - jobID, + job_id, class_ad_job, class_ad_req, owner, @@ -178,8 +179,8 @@ async def checkAndPrepareJob( checkAndPrepareJob, ) - retVal = checkAndPrepareJob( - jobID, + ret_val = checkAndPrepareJob( + job_id, class_ad_job, class_ad_req, owner, @@ -188,13 +189,13 @@ async def checkAndPrepareJob( vo, ) - if not retVal["OK"]: - if cmpError(retVal, EWMSSUBM): - await self.setJobAttributes(jobID, job_attrs) + if not ret_val["OK"]: + if cmpError(ret_val, EWMSSUBM): + await self.set_job_attributes(job_id, job_attrs) - returnValueOrRaise(retVal) + returnValueOrRaise(ret_val) - async def setJobJDL(self, job_id, jdl): + async def set_job_jdl(self, job_id, jdl): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = ( @@ -202,7 +203,7 @@ async def setJobJDL(self, job_id, jdl): ) await self.conn.execute(stmt) - async def setJobJDLsBulk(self, jdls): + async def set_job_jdl_bulk(self, jdls): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL await self.conn.execute( @@ -212,19 +213,19 @@ async def setJobJDLsBulk(self, jdls): [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()], ) - async def setJobAttributesBulk(self, jobData): + async def set_job_attributes_bulk(self, job_data): """TODO: add myDate and force parameters.""" - for job_id in jobData.keys(): - if "Status" in jobData[job_id]: - jobData[job_id].update( + for job_id in job_data.keys(): + if "Status" in job_data[job_id]: + job_data[job_id].update( {"LastUpdateTime": datetime.now(tz=timezone.utc)} ) - columns = set(key for attrs in jobData.values() for key in attrs.keys()) + columns = set(key for attrs in job_data.values() for key in attrs.keys()) case_expressions = { column: case( *[ (Jobs.__table__.c.JobID == job_id, attrs[column]) - for job_id, attrs in jobData.items() + for job_id, attrs in job_data.items() if column in attrs ], else_=getattr(Jobs.__table__.c, column), # Retain original value @@ -235,25 +236,11 @@ async def setJobAttributesBulk(self, jobData): stmt = ( Jobs.__table__.update() .values(**case_expressions) - .where(Jobs.__table__.c.JobID.in_(jobData.keys())) + .where(Jobs.__table__.c.JobID.in_(job_data.keys())) ) await self.conn.execute(stmt) - async def getJobJDL(self, job_id: int, original: bool = False) -> str: - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL - - if original: - stmt = select(JobJDLs.OriginalJDL).where(JobJDLs.JobID == job_id) - else: - stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id) - - jdl = (await self.conn.execute(stmt)).scalar_one() - if jdl: - jdl = extractJDL(jdl) - - return jdl - - async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]: + async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int | str, str]: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL if original: @@ -278,7 +265,7 @@ async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: **dict((await self.conn.execute(stmt)).one()._mapping) ) except NoResultFound as 
diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py
index 9774c523..d637dd48 100644
--- a/diracx-db/src/diracx/db/sql/job_logging/db.py
+++ b/diracx-db/src/diracx/db/sql/job_logging/db.py
@@ -12,7 +12,7 @@
 
 from collections import defaultdict
 
-from diracx.core.exceptions import JobNotFound
+from diracx.core.exceptions import JobNotFoundError
 from diracx.core.models import (
     JobStatus,
     JobStatusReturn,
@@ -212,7 +212,7 @@ async def get_wms_time_stamps(self, job_id):
         ).where(LoggingInfo.JobID == job_id)
         rows = await self.conn.execute(stmt)
         if not rows.rowcount:
-            raise JobNotFound(job_id) from None
+            raise JobNotFoundError(job_id) from None
 
         for event, etime in rows:
             result[event] = str(etime + MAGIC_EPOC_NUMBER)
diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
index db72a7f9..28462778 100644
--- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
+++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
@@ -5,10 +5,10 @@
 import sqlalchemy
 
 from diracx.core.models import SandboxInfo, SandboxType, UserInfo
-from diracx.db.sql.utils import BaseSQLDB, utcnow
+from diracx.db.sql.utils import BaseSQLDB, UTCNow
 
 from .schema import Base as SandboxMetadataDBBase
-from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes
+from .schema import SandBoxes, SBEntityMapping, SBOwners
 
 
 class SandboxMetadataDB(BaseSQLDB):
@@ -17,16 +17,16 @@ class SandboxMetadataDB(BaseSQLDB):
     async def upsert_owner(self, user: UserInfo) -> int:
         """Get the id of the owner from the database."""
         # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
-        stmt = sqlalchemy.select(sb_Owners.OwnerID).where(
-            sb_Owners.Owner == user.preferred_username,
-            sb_Owners.OwnerGroup == user.dirac_group,
-            sb_Owners.VO == user.vo,
+        stmt = sqlalchemy.select(SBOwners.OwnerID).where(
+            SBOwners.Owner == user.preferred_username,
+            SBOwners.OwnerGroup == user.dirac_group,
+            SBOwners.VO == user.vo,
         )
         result = await self.conn.execute(stmt)
         if owner_id := result.scalar_one_or_none():
             return owner_id
 
-        stmt = sqlalchemy.insert(sb_Owners).values(
+        stmt = sqlalchemy.insert(SBOwners).values(
             Owner=user.preferred_username,
             OwnerGroup=user.dirac_group,
             VO=user.vo,
@@ -53,13 +53,13 @@ async def insert_sandbox(
         """Add a new sandbox in SandboxMetadataDB."""
         # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
         owner_id = await self.upsert_owner(user)
-        stmt = sqlalchemy.insert(sb_SandBoxes).values(
+        stmt = sqlalchemy.insert(SandBoxes).values(
             OwnerId=owner_id,
             SEName=se_name,
             SEPFN=pfn,
             Bytes=size,
-            RegistrationTime=utcnow(),
-            LastAccessTime=utcnow(),
+            RegistrationTime=UTCNow(),
+            LastAccessTime=UTCNow(),
         )
         try:
             result = await self.conn.execute(stmt)
@@ -70,17 +70,17 @@ async def insert_sandbox(
 
     async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
         stmt = (
-            sqlalchemy.update(sb_SandBoxes)
-            .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn)
-            .values(LastAccessTime=utcnow())
+            sqlalchemy.update(SandBoxes)
+            .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn)
+            .values(LastAccessTime=UTCNow())
         )
         result = await self.conn.execute(stmt)
         assert result.rowcount == 1
 
     async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
         """Checks if a sandbox exists and has been assigned."""
-        stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
-            sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
+        stmt: sqlalchemy.Executable = sqlalchemy.select(SandBoxes.Assigned).where(
+            SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn
         )
         result = await self.conn.execute(stmt)
         is_assigned = result.scalar_one()
@@ -97,11 +97,11 @@ async def get_sandbox_assigned_to_job(
         """Get the sandbox assigned to the job."""
         entity_id = self.jobid_to_entity_id(job_id)
         stmt = (
-            sqlalchemy.select(sb_SandBoxes.SEPFN)
-            .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId)
+            sqlalchemy.select(SandBoxes.SEPFN)
+            .where(SandBoxes.SBId == SBEntityMapping.SBId)
             .where(
-                sb_EntityMapping.EntityId == entity_id,
-                sb_EntityMapping.Type == sb_type,
+                SBEntityMapping.EntityId == entity_id,
+                SBEntityMapping.Type == sb_type,
             )
         )
         result = await self.conn.execute(stmt)
@@ -119,21 +119,21 @@ async def assign_sandbox_to_jobs(
             # Define the entity id as 'Entity:entity_id' due to the DB definition:
             entity_id = self.jobid_to_entity_id(job_id)
             select_sb_id = sqlalchemy.select(
-                sb_SandBoxes.SBId,
+                SandBoxes.SBId,
                 sqlalchemy.literal(entity_id).label("EntityId"),
                 sqlalchemy.literal(sb_type).label("Type"),
             ).where(
-                sb_SandBoxes.SEName == se_name,
-                sb_SandBoxes.SEPFN == pfn,
+                SandBoxes.SEName == se_name,
+                SandBoxes.SEPFN == pfn,
             )
-            stmt = sqlalchemy.insert(sb_EntityMapping).from_select(
+            stmt = sqlalchemy.insert(SBEntityMapping).from_select(
                 ["SBId", "EntityId", "Type"], select_sb_id
             )
             await self.conn.execute(stmt)
 
             stmt = (
-                sqlalchemy.update(sb_SandBoxes)
-                .where(sb_SandBoxes.SEPFN == pfn)
+                sqlalchemy.update(SandBoxes)
+                .where(SandBoxes.SEPFN == pfn)
                 .values(Assigned=True)
             )
             result = await self.conn.execute(stmt)
@@ -143,29 +143,29 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
         """Delete mapping between jobs and sandboxes."""
         for job_id in jobs_ids:
             entity_id = self.jobid_to_entity_id(job_id)
-            sb_sel_stmt = sqlalchemy.select(sb_SandBoxes.SBId)
+            sb_sel_stmt = sqlalchemy.select(SandBoxes.SBId)
             sb_sel_stmt = sb_sel_stmt.join(
-                sb_EntityMapping, sb_EntityMapping.SBId == sb_SandBoxes.SBId
+                SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId
             )
-            sb_sel_stmt = sb_sel_stmt.where(sb_EntityMapping.EntityId == entity_id)
+            sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id)
             result = await self.conn.execute(sb_sel_stmt)
             sb_ids = [row.SBId for row in result]
 
-            del_stmt = sqlalchemy.delete(sb_EntityMapping).where(
-                sb_EntityMapping.EntityId == entity_id
+            del_stmt = sqlalchemy.delete(SBEntityMapping).where(
+                SBEntityMapping.EntityId == entity_id
             )
             await self.conn.execute(del_stmt)
 
-            sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
-                sb_EntityMapping.SBId.in_(sb_ids)
+            sb_entity_sel_stmt = sqlalchemy.select(SBEntityMapping.SBId).where(
+                SBEntityMapping.SBId.in_(sb_ids)
             )
             result = await self.conn.execute(sb_entity_sel_stmt)
             remaining_sb_ids = [row.SBId for row in result]
             if not remaining_sb_ids:
                 unassign_stmt = (
-                    sqlalchemy.update(sb_SandBoxes)
-                    .where(sb_SandBoxes.SBId.in_(sb_ids))
+                    sqlalchemy.update(SandBoxes)
+                    .where(SandBoxes.SBId.in_(sb_ids))
                     .values(Assigned=False)
                 )
                 await self.conn.execute(unassign_stmt)
diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py
index 8c849c67..5864ea42 100644
--- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py
+++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py
@@ -14,7 +14,7 @@
 Base = declarative_base()
 
 
-class sb_Owners(Base):
+class SBOwners(Base):
     __tablename__ = "sb_Owners"
     OwnerID = Column(Integer, autoincrement=True)
     Owner = Column(String(32))
@@ -23,7 +23,7 @@ class sb_Owners(Base):
     __table_args__ = (PrimaryKeyConstraint("OwnerID"),)
 
 
-class sb_SandBoxes(Base):
+class SandBoxes(Base):
     __tablename__ = "sb_SandBoxes"
     SBId = Column(Integer, autoincrement=True)
     OwnerId = Column(Integer)
@@ -40,7 +40,7 @@ class sb_SandBoxes(Base):
     )
 
 
-class sb_EntityMapping(Base):
+class SBEntityMapping(Base):
     __tablename__ = "sb_EntityMapping"
     SBId = Column(Integer)
     EntityId = Column(String(128))
diff --git a/diracx-db/src/diracx/db/sql/task_queue/db.py b/diracx-db/src/diracx/db/sql/task_queue/db.py
index 537f128e..ff701509 100644
--- a/diracx-db/src/diracx/db/sql/task_queue/db.py
+++ b/diracx-db/src/diracx/db/sql/task_queue/db.py
@@ -121,12 +121,12 @@ async def recalculate_tq_shares_for_entity(
         # TODO: I guess the rows are already a list of tuples
         # maybe refactor
         data = [(r[0], r[1]) for r in rows if r]
-        numOwners = len(data)
+        num_owners = len(data)
         # If there are no owners do nothing
-        if numOwners == 0:
+        if num_owners == 0:
             return
         # Split the share amongst the number of owners
-        entities_shares = {row[0]: job_share / numOwners for row in data}
+        entities_shares = {row[0]: job_share / num_owners for row in data}
 
         # TODO: implement the following
         # If corrector is enabled let it work its magic
diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py
index 390588e6..69b3fa86 100644
--- a/diracx-db/src/diracx/db/sql/utils/__init__.py
+++ b/diracx-db/src/diracx/db/sql/utils/__init__.py
@@ -26,7 +26,7 @@
 from diracx.core.extensions import select_from_extension
 from diracx.core.models import SortDirection
 from diracx.core.settings import SqlalchemyDsn
-from diracx.db.exceptions import DBUnavailable
+from diracx.db.exceptions import DBUnavailableError
 
 if TYPE_CHECKING:
     from sqlalchemy.types import TypeEngine
@@ -34,32 +34,32 @@
 logger = logging.getLogger(__name__)
 
 
-class utcnow(expression.FunctionElement):
+class UTCNow(expression.FunctionElement):
     type: TypeEngine = DateTime()
     inherit_cache: bool = True
 
 
-@compiles(utcnow, "postgresql")
+@compiles(UTCNow, "postgresql")
 def pg_utcnow(element, compiler, **kw) -> str:
     return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
 
 
-@compiles(utcnow, "mssql")
+@compiles(UTCNow, "mssql")
 def ms_utcnow(element, compiler, **kw) -> str:
     return "GETUTCDATE()"
 
 
-@compiles(utcnow, "mysql")
+@compiles(UTCNow, "mysql")
 def mysql_utcnow(element, compiler, **kw) -> str:
     return "(UTC_TIMESTAMP)"
 
 
-@compiles(utcnow, "sqlite")
+@compiles(UTCNow, "sqlite")
 def sqlite_utcnow(element, compiler, **kw) -> str:
     return "DATETIME('now')"
 
 
-class date_trunc(expression.FunctionElement):
+class DateTrunc(expression.FunctionElement):
     """Sqlalchemy function to truncate a date to a given resolution.
 
     Primarily used to be able to query for a specific resolution of a date e.g.
@@ -77,7 +77,7 @@ def __init__(self, *args, time_resolution, **kwargs) -> None:
         self._time_resolution = time_resolution
 
 
-@compiles(date_trunc, "postgresql")
+@compiles(DateTrunc, "postgresql")
 def pg_date_trunc(element, compiler, **kw):
     res = {
         "SECOND": "second",
@@ -90,7 +90,7 @@ def pg_date_trunc(element, compiler, **kw):
     return f"date_trunc('{res}', {compiler.process(element.clauses)})"
 
 
-@compiles(date_trunc, "mysql")
+@compiles(DateTrunc, "mysql")
 def mysql_date_trunc(element, compiler, **kw):
     pattern = {
         "SECOND": "%Y-%m-%d %H:%i:%S",
@@ -105,7 +105,7 @@ def mysql_date_trunc(element, compiler, **kw):
     return compiler.process(func.date_format(dt_col, pattern))
 
 
-@compiles(date_trunc, "sqlite")
+@compiles(DateTrunc, "sqlite")
 def sqlite_date_trunc(element, compiler, **kw):
     pattern = {
         "SECOND": "%Y-%m-%d %H:%M:%S",
@@ -130,10 +130,10 @@ def substract_date(**kwargs: float) -> datetime:
 
 Column: partial[RawColumn] = partial(RawColumn, nullable=False)
 NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True)
-DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow())
+DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=UTCNow())
 
 
-def EnumColumn(enum_type, **kwargs):
+def EnumColumn(enum_type, **kwargs):  # noqa: N802
     return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)
 
 
@@ -167,7 +167,7 @@ class SQLDBError(Exception):
     pass
 
 
-class SQLDBUnavailable(DBUnavailable, SQLDBError):
+class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
     """Used whenever we encounter a problem with the DB connection."""
 
 
@@ -324,7 +324,7 @@ async def __aenter__(self) -> Self:
         try:
             self._conn.set(await self.engine.connect().__aenter__())
         except Exception as e:
-            raise SQLDBUnavailable(
+            raise SQLDBUnavailableError(
                 f"Cannot connect to {self.__class__.__name__}"
             ) from e
 
@@ -350,7 +350,7 @@ async def ping(self):
         try:
             await self.conn.scalar(select(1))
         except OperationalError as e:
-            raise SQLDBUnavailable("Cannot ping the DB") from e
+            raise SQLDBUnavailableError("Cannot ping the DB") from e
 
 
 def find_time_resolution(value):
@@ -394,7 +394,7 @@ def apply_search_filters(column_mapping, stmt, search):
             if "value" in query and isinstance(query["value"], str):
                 resolution, value = find_time_resolution(query["value"])
                 if resolution:
-                    column = date_trunc(column, time_resolution=resolution)
+                    column = DateTrunc(column, time_resolution=resolution)
                 query["value"] = value
 
             if query.get("values"):
@@ -406,7 +406,7 @@ def apply_search_filters(column_mapping, stmt, search):
                         f"Cannot mix different time resolutions in {query=}"
                     )
                 if resolution := resolutions[0]:
-                    column = date_trunc(column, time_resolution=resolution)
+                    column = DateTrunc(column, time_resolution=resolution)
                 query["values"] = values
 
         if query["operator"] == "eq":
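`UTCNow` (ex-`utcnow`) is the standard SQLAlchemy recipe for a dialect-aware SQL function: a `FunctionElement` subclass plus one `@compiles` override per dialect. A runnable sketch of the same pattern; the generic fallback compiler is an addition for illustration and is not part of this diff:

```python
from sqlalchemy import DateTime, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression


class UTCNow(expression.FunctionElement):
    type = DateTime()
    inherit_cache = True


@compiles(UTCNow, "sqlite")
def sqlite_utcnow(element, compiler, **kw) -> str:
    return "DATETIME('now')"


@compiles(UTCNow)
def default_utcnow(element, compiler, **kw) -> str:
    # Fallback for dialects without a dedicated override (illustrative only).
    return "CURRENT_TIMESTAMP"


print(select(UTCNow()).compile(dialect=sqlite.dialect()))
# -> SELECT DATETIME('now') AS anon_1
```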
diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py
index 16ed5ba7..7447a619 100644
--- a/diracx-db/src/diracx/db/sql/utils/job.py
+++ b/diracx-db/src/diracx/db/sql/utils/job.py
@@ -99,7 +99,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
             # TODO is this even needed?
             class_ad_job.insertAttributeInt("JobID", job_id)
 
-            await job_db.checkAndPrepareJob(
+            await job_db.check_and_prepare_job(
                 job_id,
                 class_ad_job,
                 class_ad_req,
@@ -243,7 +243,7 @@ def parse_jdl(job_id, job_jdl):
     job_jdls = {
         jobid: parse_jdl(jobid, jdl)
         for jobid, jdl in (
-            (await job_db.getJobJDLs(surviving_job_ids, original=True)).items()
+            (await job_db.get_job_jdls(surviving_job_ids, original=True)).items()
         )
     }
 
@@ -319,7 +319,7 @@ def parse_jdl(job_id, job_jdl):
     # BULK JDL UPDATE
     # DATABASE OPERATION
-    await job_db.setJobJDLsBulk(jdl_changes)
+    await job_db.set_job_jdl_bulk(jdl_changes)
 
     return {
         "failed": failed,
@@ -499,7 +499,7 @@ async def set_job_status_bulk(
             )
         )
 
-    await job_db.setJobAttributesBulk(job_attribute_updates)
+    await job_db.set_job_attributes_bulk(job_attribute_updates)
 
     await remove_jobs_from_task_queue(
         list(deletable_killable_jobs), config, task_queue_db, background_task
diff --git a/diracx-db/tests/jobs/test_jobDB.py b/diracx-db/tests/jobs/test_job_db.py
similarity index 99%
rename from diracx-db/tests/jobs/test_jobDB.py
rename to diracx-db/tests/jobs/test_job_db.py
index aa17035b..060bd7d8 100644
--- a/diracx-db/tests/jobs/test_jobDB.py
+++ b/diracx-db/tests/jobs/test_job_db.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from diracx.core.exceptions import InvalidQueryError, JobNotFound
+from diracx.core.exceptions import InvalidQueryError, JobNotFoundError
 from diracx.core.models import (
     ScalarSearchOperator,
     ScalarSearchSpec,
@@ -333,5 +333,5 @@ async def test_search_pagination(job_db):
 async def test_set_job_command_invalid_job_id(job_db: JobDB):
     """Test that setting a command for a non-existent job raises JobNotFoundError."""
     async with job_db as job_db:
-        with pytest.raises(JobNotFound):
+        with pytest.raises(JobNotFoundError):
             await job_db.set_job_command(123456, "test_command")
diff --git a/diracx-db/tests/jobs/test_jobLoggingDB.py b/diracx-db/tests/jobs/test_job_logging_db.py
similarity index 100%
rename from diracx-db/tests/jobs/test_jobLoggingDB.py
rename to diracx-db/tests/jobs/test_job_logging_db.py
diff --git a/diracx-db/tests/jobs/test_sandbox_metadata.py b/diracx-db/tests/jobs/test_sandbox_metadata.py
index 06149189..bcb1c2cc 100644
--- a/diracx-db/tests/jobs/test_sandbox_metadata.py
+++ b/diracx-db/tests/jobs/test_sandbox_metadata.py
@@ -9,7 +9,7 @@
 
 from diracx.core.models import SandboxInfo, UserInfo
 from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
-from diracx.db.sql.sandbox_metadata.schema import sb_EntityMapping, sb_SandBoxes
+from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping
 
 
 @pytest.fixture
@@ -89,7 +89,7 @@ async def _dump_db(
     """
     async with sandbox_metadata_db:
         stmt = sqlalchemy.select(
-            sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime
+            SandBoxes.SEPFN, SandBoxes.OwnerId, SandBoxes.LastAccessTime
         )
         res = await sandbox_metadata_db.conn.execute(stmt)
         return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res}
@@ -109,7 +109,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
     await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100)
 
     async with sandbox_metadata_db:
-        stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
+        stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN)
         res = await sandbox_metadata_db.conn.execute(stmt)
         db_contents = {row.SEPFN: row.SBId for row in res}
     sb_id_1 = db_contents[pfn]
@@ -120,7 +120,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
     # Check there is no mapping
     async with sandbox_metadata_db:
         stmt = sqlalchemy.select(
-            sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
+            SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type
         )
         res = await sandbox_metadata_db.conn.execute(stmt)
         db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
@@ -134,7 +134,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
     # Check if sandbox and job are mapped
     async with sandbox_metadata_db:
         stmt = sqlalchemy.select(
-            sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
+            SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type
         )
         res = await sandbox_metadata_db.conn.execute(stmt)
         db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
@@ -144,7 +144,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
     assert sb_type == "Output"
 
     async with sandbox_metadata_db:
-        stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
+        stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN)
         res = await sandbox_metadata_db.conn.execute(stmt)
         db_contents = {row.SEPFN: row.SBId for row in res}
     sb_id_1 = db_contents[pfn]
@@ -158,8 +158,8 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
 
     # Entity should not exist anymore
     async with sandbox_metadata_db:
-        stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
-            sb_EntityMapping.EntityId == entity_id_1
+        stmt = sqlalchemy.select(SBEntityMapping.SBId).where(
+            SBEntityMapping.EntityId == entity_id_1
         )
         res = await sandbox_metadata_db.conn.execute(stmt)
         entity_sb_id = [row.SBId for row in res]
@@ -170,7 +170,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs(
     assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se) is False
     # Check the mapping has been deleted
     async with sandbox_metadata_db:
-        stmt = sqlalchemy.select(sb_EntityMapping.SBId)
+        stmt = sqlalchemy.select(SBEntityMapping.SBId)
         res = await sandbox_metadata_db.conn.execute(stmt)
         res_sb_id = [row.SBId for row in res]
     assert sb_id_1 not in res_sb_id
diff --git a/diracx-db/tests/opensearch/test_connection.py b/diracx-db/tests/opensearch/test_connection.py
index 4b2e3877..1e61760f 100644
--- a/diracx-db/tests/opensearch/test_connection.py
+++ b/diracx-db/tests/opensearch/test_connection.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from diracx.db.os.utils import OpenSearchDBUnavailable
+from diracx.db.os.utils import OpenSearchDBUnavailableError
 from diracx.testing.osdb import OPENSEARCH_PORT, DummyOSDB, require_port_availability
 
 
@@ -10,7 +10,7 @@ async def _ensure_db_unavailable(db: DummyOSDB):
     """Helper function which raises an exception if we manage to connect to the DB."""
     async with db.client_context():
         async with db:
-            with pytest.raises(OpenSearchDBUnavailable):
+            with pytest.raises(OpenSearchDBUnavailableError):
                 await db.ping()
diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py
similarity index 100%
rename from diracx-db/tests/pilot_agents/test_pilotAgentsDB.py
rename to diracx-db/tests/pilot_agents/test_pilot_agents_db.py
diff --git a/diracx-db/tests/test_dummyDB.py b/diracx-db/tests/test_dummy_db.py
similarity index 90%
rename from diracx-db/tests/test_dummyDB.py
rename to diracx-db/tests/test_dummy_db.py
index 90ed15d0..e7011539 100644
--- a/diracx-db/tests/test_dummyDB.py
+++ b/diracx-db/tests/test_dummy_db.py
@@ -7,7 +7,7 @@
 
 from diracx.core.exceptions import InvalidQueryError
 from diracx.db.sql.dummy.db import DummyDB
-from diracx.db.sql.utils import SQLDBUnavailable
+from diracx.db.sql.utils import SQLDBUnavailableError
 
 # Each DB test class must define a fixture looking like this one.
 # It allows getting an instance of an in-memory DB,
@@ -44,14 +44,14 @@ async def test_insert_and_summary(dummy_db: DummyDB):
 
     # Check that there are now 10 cars assigned to a single driver
     async with dummy_db as dummy_db:
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
 
         assert result[0]["count"] == 10
 
     # Test the selection
     async with dummy_db as dummy_db:
         result = await dummy_db.summary(
-            ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}]
+            ["owner_id"], [{"parameter": "model", "operator": "eq", "value": "model_1"}]
         )
 
         assert result[0]["count"] == 1
@@ -59,7 +59,7 @@ async def test_insert_and_summary(dummy_db: DummyDB):
     async with dummy_db as dummy_db:
         with pytest.raises(InvalidQueryError):
             result = await dummy_db.summary(
-                ["ownerID"],
+                ["owner_id"],
                 [
                     {
                         "parameter": "model",
@@ -73,7 +73,7 @@ async def test_insert_and_summary(dummy_db: DummyDB):
 async def test_bad_connection():
     dummy_db = DummyDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name")
     async with dummy_db.engine_context():
-        with pytest.raises(SQLDBUnavailable):
+        with pytest.raises(SQLDBUnavailableError):
             async with dummy_db:
                 dummy_db.ping()
@@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db):
         assert dummy_db.conn
 
         # First we check that the DB is empty
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
         # Add data
@@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db):
         )
         assert result
 
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert result[0]["count"] == 10
 
     # The connection is closed when the context manager is exited
@@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db):
     # Start a new transaction
     # The previous data should still be there because the transaction was committed (successful)
     async with dummy_db as dummy_db:
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert result[0]["count"] == 10
 
 
@@ -134,7 +134,7 @@ async def test_failed_transaction(dummy_db):
         assert dummy_db.conn
 
         # First we check that the DB is empty
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
         # Add data
@@ -159,7 +159,7 @@ async def test_failed_transaction(dummy_db):
     # Start a new transaction
     # The previous data should not be there because the transaction was rolled back (failed)
    async with dummy_db as dummy_db:
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
 
@@ -203,7 +203,7 @@ async def test_successful_with_exception_transaction(dummy_db):
         assert dummy_db.conn
 
         # First we check that the DB is empty
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
         # Add data
@@ -217,7 +217,7 @@ async def test_successful_with_exception_transaction(dummy_db):
         )
         assert result
 
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert result[0]["count"] == 10
 
     # This will raise an exception but the transaction will be rolled back
@@ -231,7 +231,7 @@ async def test_successful_with_exception_transaction(dummy_db):
     # Start a new transaction
     # The previous data should not be there because the transaction was rolled back (failed)
     async with dummy_db as dummy_db:
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
     # Start a new transaction, this time we commit it manually
@@ -240,7 +240,7 @@ async def test_successful_with_exception_transaction(dummy_db):
         assert dummy_db.conn
 
         # First we check that the DB is empty
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert not result
 
         # Add data
@@ -254,7 +254,7 @@ async def test_successful_with_exception_transaction(dummy_db):
         )
         assert result
 
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert result[0]["count"] == 10
 
     # Manually commit the transaction, and then raise an exception
@@ -271,5 +271,5 @@ async def test_successful_with_exception_transaction(dummy_db):
     # Start a new transaction
     # The previous data should be there because the transaction was committed before the exception
     async with dummy_db as dummy_db:
-        result = await dummy_db.summary(["ownerID"], [])
+        result = await dummy_db.summary(["owner_id"], [])
         assert result[0]["count"] == 10
diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py
index b3725b2f..d17fbd8f 100644
--- a/diracx-routers/src/diracx/routers/__init__.py
+++ b/diracx-routers/src/diracx/routers/__init__.py
@@ -34,11 +34,11 @@
 from uvicorn.logging import AccessFormatter, DefaultFormatter
 
 from diracx.core.config import ConfigSource
-from diracx.core.exceptions import DiracError, DiracHttpResponse
+from diracx.core.exceptions import DiracError, DiracHttpResponseError
 from diracx.core.extensions import select_from_extension
 from diracx.core.settings import ServiceSettingsBase
 from diracx.core.utils import dotenv_files_from_environment
-from diracx.db.exceptions import DBUnavailable
+from diracx.db.exceptions import DBUnavailableError
 from diracx.db.os.utils import BaseOSDB
 from diracx.db.sql.utils import BaseSQLDB
 from diracx.routers.access_policies import BaseAccessPolicy, check_permissions
@@ -291,10 +291,10 @@ def create_app_inner(
     handler_signature = Callable[[Request, Exception], Response | Awaitable[Response]]
     app.add_exception_handler(DiracError, cast(handler_signature, dirac_error_handler))
     app.add_exception_handler(
-        DiracHttpResponse, cast(handler_signature, http_response_handler)
+        DiracHttpResponseError, cast(handler_signature, http_response_handler)
     )
     app.add_exception_handler(
-        DBUnavailable, cast(handler_signature, route_unavailable_error_hander)
+        DBUnavailableError, cast(handler_signature, route_unavailable_error_handler)
     )
 
     # TODO: remove the CORSMiddleware once we figure out how to launch
@@ -393,11 +393,11 @@ def dirac_error_handler(request: Request, exc: DiracError) -> Response:
     )
 
 
-def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response:
+def http_response_handler(request: Request, exc: DiracHttpResponseError) -> Response:
     return JSONResponse(status_code=exc.status_code, content=exc.data)
 
 
-def route_unavailable_error_hander(request: Request, exc: DBUnavailable):
+def route_unavailable_error_handler(request: Request, exc: DBUnavailableError):
     return JSONResponse(
         status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
         headers={"Retry-After": "10"},
@@ -435,7 +435,7 @@ async def is_db_unavailable(db: BaseSQLDB | BaseOSDB) -> str:
             await db.ping()
             _db_alive_cache[db] = ""
 
-        except DBUnavailable as e:
+        except DBUnavailableError as e:
             _db_alive_cache[db] = e.args[0]
 
     return _db_alive_cache[db]
@@ -448,7 +448,7 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2]:
         async with db:
             # Check whether the connection still works before executing the query
             if reason := await is_db_unavailable(db):
-                raise DBUnavailable(reason)
+                raise DBUnavailableError(reason)
             yield db
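The 503 mapping for `DBUnavailableError` can be seen in isolation. A minimal sketch with a standalone FastAPI app (the app and exception here are hypothetical stand-ins mirroring the handler registered above):

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class DBUnavailableError(Exception):
    """Stand-in for diracx.db.exceptions.DBUnavailableError."""


app = FastAPI()


@app.exception_handler(DBUnavailableError)
def route_unavailable_error_handler(request: Request, exc: DBUnavailableError):
    # Clients are told to retry once the database is reachable again.
    return JSONResponse(
        status_code=503,
        headers={"Retry-After": "10"},
        content={"detail": str(exc)},
    )
```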
diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py
index 14103add..8346e2b9 100644
--- a/diracx-routers/src/diracx/routers/auth/token.py
+++ b/diracx-routers/src/diracx/routers/auth/token.py
@@ -12,7 +12,7 @@
 from fastapi import Depends, Form, Header, HTTPException, status
 
 from diracx.core.exceptions import (
-    DiracHttpResponse,
+    DiracHttpResponseError,
     ExpiredFlowError,
     PendingAuthorizationError,
 )
@@ -120,15 +120,15 @@ async def get_oidc_token_info_from_device_flow(
             device_code, settings.device_flow_expiration_seconds
         )
     except PendingAuthorizationError as e:
-        raise DiracHttpResponse(
+        raise DiracHttpResponseError(
             status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"}
         ) from e
     except ExpiredFlowError as e:
-        raise DiracHttpResponse(
+        raise DiracHttpResponseError(
             status.HTTP_401_UNAUTHORIZED, {"error": "expired_token"}
         ) from e
-    # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"})
-    # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"})
+    # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"})
+    # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"})
 
     if info["client_id"] != client_id:
         raise HTTPException(
diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py
index 3b881361..7ca8b523 100644
--- a/diracx-routers/src/diracx/routers/auth/utils.py
+++ b/diracx-routers/src/diracx/routers/auth/utils.py
@@ -262,7 +262,7 @@ async def initiate_authorization_flow_with_iam(
         state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite
     )
 
-    urlParams = [
+    url_params = [
         "response_type=code",
         f"code_challenge={code_challenge}",
         "code_challenge_method=S256",
@@ -271,7 +271,7 @@ async def initiate_authorization_flow_with_iam(
         "scope=openid%20profile",
         f"state={encrypted_state}",
     ]
-    authorization_flow_url = f"{authorization_endpoint}?{'&'.join(urlParams)}"
+    authorization_flow_url = f"{authorization_endpoint}?{'&'.join(url_params)}"
 
     return authorization_flow_url
r.json()["success"][str(valid_job_id)]["Status"] == new_status + assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == new_minor_status # Act r = normal_user_client.post( @@ -588,15 +588,15 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): assert j["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.CHECKING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.CHECKING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -604,8 +604,8 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): # Assert assert r.status_code == 200, r.json() - assert r.json()["success"][str(valid_job_id)]["Status"] == NEW_STATUS - assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()["success"][str(valid_job_id)]["Status"] == new_status + assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -621,8 +621,8 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): ) assert r.status_code == 200, r.json() assert r.json()[0]["JobID"] == valid_job_id - assert r.json()[0]["Status"] == NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] == new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -700,15 +700,15 @@ def test_set_job_status_cannot_make_impossible_transitions( assert r.json()[0]["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.RUNNING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.RUNNING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -718,8 +718,8 @@ def test_set_job_status_cannot_make_impossible_transitions( assert r.status_code == 200, r.json() success = r.json()["success"] assert len(success) == 1, r.json() - assert success[str(valid_job_id)]["Status"] != NEW_STATUS - assert success[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(valid_job_id)]["Status"] != new_status + assert success[str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -734,8 +734,8 @@ def test_set_job_status_cannot_make_impossible_transitions( }, ) assert r.status_code == 200, r.json() - assert r.json()[0]["Status"] != NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] != new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -760,15 +760,15 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) assert r.json()[0]["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.RUNNING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.RUNNING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + 
"MinorStatus": new_minor_status, } } }, @@ -779,8 +779,8 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) # Assert assert r.status_code == 200, r.json() - assert success[str(valid_job_id)]["Status"] == NEW_STATUS - assert success[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(valid_job_id)]["Status"] == new_status + assert success[str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -796,8 +796,8 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) ) assert r.status_code == 200, r.json() assert r.json()[0]["JobID"] == valid_job_id - assert r.json()[0]["Status"] == NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] == new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -822,15 +822,15 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): assert r.json()[0]["MinorStatus"] == "Bulk transaction confirmation" # Act - NEW_STATUS = JobStatus.CHECKING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.CHECKING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ job_id: { datetime.now(timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } for job_id in valid_job_ids @@ -842,8 +842,8 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): # Assert assert r.status_code == 200, r.json() for job_id in valid_job_ids: - assert success[str(job_id)]["Status"] == NEW_STATUS - assert success[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(job_id)]["Status"] == new_status + assert success[str(job_id)]["MinorStatus"] == new_minor_status r_get = normal_user_client.post( "/api/jobs/search", @@ -859,8 +859,8 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): ) assert r_get.status_code == 200, r_get.json() assert r_get.json()[0]["JobID"] == job_id - assert r_get.json()[0]["Status"] == NEW_STATUS - assert r_get.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r_get.json()[0]["Status"] == new_status + assert r_get.json()[0]["MinorStatus"] == new_minor_status assert r_get.json()[0]["ApplicationStatus"] == "Unknown" diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py index b73696b1..59ebca1d 100644 --- a/diracx-testing/src/diracx/testing/__init__.py +++ b/diracx-testing/src/diracx/testing/__init__.py @@ -170,11 +170,16 @@ class AlwaysAllowAccessPolicy(BaseAccessPolicy): """Dummy access policy.""" async def policy( - policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs + policy_name: str, # noqa: N805 + user_info: AuthorizedUserInfo, + /, + **kwargs, ): pass - def enrich_tokens(access_payload: dict, refresh_payload: dict): + def enrich_tokens( + access_payload: dict, refresh_payload: dict # noqa: N805 + ): return {"PolicySpecific": "OpenAccessForTest"}, {} diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 282128ac..6e181a79 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -42,8 +42,8 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None: from diracx.db.sql.utils import DateNowColumn # Dynamically create a 
diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py
index 282128ac..6e181a79 100644
--- a/diracx-testing/src/diracx/testing/mock_osdb.py
+++ b/diracx-testing/src/diracx/testing/mock_osdb.py
@@ -42,8 +42,8 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         from diracx.db.sql.utils import DateNowColumn
 
         # Dynamically create a subclass of BaseSQLDB so we get clearer errors
-        MockedDB = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
-        self._sql_db = MockedDB(connection_kwargs["sqlalchemy_dsn"])
+        mocked_db = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
+        self._sql_db = mocked_db(connection_kwargs["sqlalchemy_dsn"])
 
         # Dynamically create the table definition based on the fields
         columns = [
@@ -53,16 +53,16 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         for field, field_type in self.fields.items():
             match field_type["type"]:
                 case "date":
-                    ColumnType = DateNowColumn
+                    column_type = DateNowColumn
                 case "long":
-                    ColumnType = partial(Column, type_=Integer)
+                    column_type = partial(Column, type_=Integer)
                 case "keyword":
-                    ColumnType = partial(Column, type_=String(255))
+                    column_type = partial(Column, type_=String(255))
                 case "text":
-                    ColumnType = partial(Column, type_=String(64 * 1024))
+                    column_type = partial(Column, type_=String(64 * 1024))
                 case _:
                     raise NotImplementedError(f"Unknown field type: {field_type=}")
-            columns.append(ColumnType(field, default=None))
+            columns.append(column_type(field, default=None))
 
         self._sql_db.metadata = MetaData()
         self._table = Table("dummy", self._sql_db.metadata, *columns)
@@ -158,6 +158,6 @@ def fake_available_osdb_implementations(name, *, real_available_implementations)
 
     # Dynamically generate a class that inherits from the first implementation
     # but that also has the MockOSDBMixin
-    MockParameterDB = type(name, (MockOSDBMixin, implementations[0]), {})
+    mock_parameter_db = type(name, (MockOSDBMixin, implementations[0]), {})
 
-    return [MockParameterDB] + implementations
+    return [mock_parameter_db] + implementations
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
index 8f56ce4e..4357c574 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
@@ -20,7 +20,7 @@ async def insert_gubbins_info(self, job_id: int, info: str):
         stmt = insert(GubbinsInfo).values(JobID=job_id, Info=info)
         await self.conn.execute(stmt)
 
-    async def getJobJDL(  # type: ignore[override]
+    async def get_job_jdl(  # type: ignore[override]
         self, job_id: int, original: bool = False, with_info=False
     ) -> str | dict[str, str]:
         """
@@ -31,7 +31,7 @@ async def getJobJDL(  # type: ignore[override]
         Note that this requires disabling the mypy error with # type: ignore[override]
         """
-        jdl = await super().getJobJDL(job_id, original=original)
+        jdl = await super().get_job_jdl(job_id, original=original)
         if not with_info:
             return jdl
 
@@ -40,7 +40,7 @@ async def getJobJDL(  # type: ignore[override]
         info = (await self.conn.execute(stmt)).scalar_one()
         return {"JDL": jdl, "Info": info}
 
-    async def setJobAttributesBulk(self, jobData):
+    async def set_job_attributes_bulk(self, job_data):
         """
         This method modifies the one in the parent class, without changing
         the arguments or the return type
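The `mocked_db`/`mock_parameter_db` renames above touch classes created at runtime with the three-argument `type()` builtin, which pep8-naming treats as ordinary local variables. A minimal self-contained illustration of that pattern (the class and base here are made up):

```python
class Base:
    def greet(self) -> str:
        return f"hello from {type(self).__name__}"


# type(name, bases, namespace) builds a class object at runtime; the
# snake_case binding is what the N rules now expect for a local value.
mocked_db = type("MockedJobParametersDB", (Base,), {})
print(mocked_db().greet())  # -> hello from MockedJobParametersDB
```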
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
index 5ce64edc..dc73d3b1 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
@@ -25,7 +25,7 @@ class LollygagDB(BaseSQLDB):
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         columns = [Cars.__table__.columns[x] for x in group_by]
 
-        stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
+        stmt = select(*columns, func.count(Cars.license_plate).label("count"))
         stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)
 
@@ -48,7 +48,7 @@ async def get_owner(self) -> list[str]:
 
     async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
         stmt = insert(Cars).values(
-            licensePlate=license_plate, model=model, ownerID=owner_id
+            license_plate=license_plate, model=model, owner_id=owner_id
         )
 
         result = await self.conn.execute(stmt)
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
index 9e7b4eba..ff3f3000 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
@@ -9,13 +9,13 @@
 class Owners(Base):
     __tablename__ = "Owners"
-    ownerID = Column(Integer, primary_key=True, autoincrement=True)
+    owner_id = Column(Integer, primary_key=True, autoincrement=True)
     creation_time = DateNowColumn()
     name = Column(String(255))
 
 
 class Cars(Base):
     __tablename__ = "Cars"
-    licensePlate = Column(Uuid(), primary_key=True)
+    license_plate = Column(Uuid(), primary_key=True)
     model = Column(String(255))
-    ownerID = Column(Integer, ForeignKey(Owners.ownerID))
+    owner_id = Column(Integer, ForeignKey(Owners.owner_id))
diff --git a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
similarity index 91%
rename from extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
rename to extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
index f98e3bdf..b64882ef 100644
--- a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
+++ b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
@@ -46,9 +46,9 @@ async def test_gubbins_info(gubbins_db):
 
     await gubbins_db.insert_gubbins_info(job_id, "info")
 
-    result = await gubbins_db.getJobJDL(job_id, original=True)
+    result = await gubbins_db.get_job_jdl(job_id, original=True)
     assert result == "[JDL]"
 
-    result = await gubbins_db.getJobJDL(job_id, with_info=True)
+    result = await gubbins_db.get_job_jdl(job_id, with_info=True)
     assert "JDL" in result
     assert result["Info"] == "info"
"value": "model_1"}] ) assert result[0]["count"] == 1 @@ -66,7 +66,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async with lollygag_db as lollygag_db: with pytest.raises(InvalidQueryError): result = await lollygag_db.summary( - ["ownerID"], + ["owner_id"], [ { "parameter": "model", @@ -80,6 +80,6 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async def test_bad_connection(): lollygag_db = LollygagDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") async with lollygag_db.engine_context(): - with pytest.raises(SQLDBUnavailable): + with pytest.raises(SQLDBUnavailableError): async with lollygag_db: lollygag_db.ping() diff --git a/extensions/gubbins/pyproject.toml b/extensions/gubbins/pyproject.toml index a10370f5..c61127cb 100644 --- a/extensions/gubbins/pyproject.toml +++ b/extensions/gubbins/pyproject.toml @@ -52,6 +52,7 @@ select = [ "FLY", # flynt "DTZ", # flake8-datetimez "S", # flake8-bandit + "N", # pep8-naming ] ignore = [ diff --git a/pyproject.toml b/pyproject.toml index 06b78e3b..77998d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ select = [ "FLY", # flynt "DTZ", # flake8-datetimez "S", # flake8-bandit + "N", # pep8-naming ] ignore = [ "B905", diff --git a/run_local.sh b/run_local.sh index b632d248..83bfbcc4 100755 --- a/run_local.sh +++ b/run_local.sh @@ -70,7 +70,7 @@ echo "" echo "1. Use the CLI:" echo "" echo " export DIRACX_URL=http://localhost:8000" -echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make-token-local.py ${signing_key}" +echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make_token_local.py ${signing_key}" echo "" echo "2. Using swagger: http://localhost:8000/api/docs" diff --git a/tests/make-token-local.py b/tests/make_token_local.py similarity index 100% rename from tests/make-token-local.py rename to tests/make_token_local.py