Skip to content

Commit

Permalink
feat(pyproject): add N rule to ruff config
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Dec 20, 2024
1 parent 15a1926 commit a7b2f36
Show file tree
Hide file tree
Showing 35 changed files with 221 additions and 228 deletions.
6 changes: 3 additions & 3 deletions diracx-core/src/diracx/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cachetools import Cache, LRUCache, TTLCache, cachedmethod
from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints

from ..exceptions import BadConfigurationVersion
from ..exceptions import BadConfigurationVersionError
from ..extensions import select_from_extension
from .schema import Config

Expand Down Expand Up @@ -159,7 +159,7 @@ def latest_revision(self) -> tuple[str, datetime]:
).strip()
modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc)
except sh.ErrorReturnCode as e:
raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e
raise BadConfigurationVersionError(f"Error parsing latest revision: {e}") from e
logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified)
return rev, modified

Expand All @@ -176,7 +176,7 @@ def read_raw(self, hexsha: str, modified: datetime) -> Config:
)
raw_obj = yaml.safe_load(blob)
except sh.ErrorReturnCode as e:
raise BadConfigurationVersion(f"Error reading configuration: {e}") from e
raise BadConfigurationVersionError(f"Error reading configuration: {e}") from e

config_class: Config = select_from_extension(group="diracx", name="config")[
0
Expand Down
1 change: 0 additions & 1 deletion diracx-core/src/diracx/core/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class DIRACConfig(BaseModel):

class JobMonitoringConfig(BaseModel):
GlobalJobsInfo: bool = True
useESForJobParametersFlag: bool = False


class JobSchedulingConfig(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions diracx-core/src/diracx/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from http import HTTPStatus


class DiracHttpResponse(RuntimeError):
class DiracHttpResponseError(RuntimeError):
def __init__(self, status_code: int, data):
self.status_code = status_code
self.data = data
Expand Down Expand Up @@ -30,15 +30,15 @@ class ConfigurationError(DiracError):
"""Used whenever we encounter a problem with the configuration."""


class BadConfigurationVersion(ConfigurationError):
class BadConfigurationVersionError(ConfigurationError):
"""The requested version is not known."""


class InvalidQueryError(DiracError):
"""It was not possible to build a valid database query from the given input."""


class JobNotFound(Exception):
class JobNotFoundError(Exception):
def __init__(self, job_id: int, detail: str | None = None):
self.job_id: int = job_id
super().__init__(f"Job {job_id} not found" + (" ({detail})" if detail else ""))
Expand Down
2 changes: 1 addition & 1 deletion diracx-db/src/diracx/db/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
class DBUnavailable(Exception):
class DBUnavailableError(Exception):
pass
6 changes: 3 additions & 3 deletions diracx-db/src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.db.exceptions import DBUnavailable
from diracx.db.exceptions import DBUnavailableError

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +25,7 @@ class OpenSearchDBError(Exception):
pass


class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError):
pass


Expand Down Expand Up @@ -152,7 +152,7 @@ async def ping(self):
be ran at every query.
"""
if not await self.client.ping():
raise OpenSearchDBUnavailable(
raise OpenSearchDBUnavailableError(
f"Failed to connect to {self.__class__.__qualname__}"
)

Expand Down
4 changes: 2 additions & 2 deletions diracx-db/src/diracx/db/sql/dummy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class DummyDB(BaseSQLDB):
async def summary(self, group_by, search) -> list[dict[str, str | int]]:
columns = [Cars.__table__.columns[x] for x in group_by]

stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
stmt = select(*columns, func.count(Cars.license_plate).label("count"))
stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
stmt = stmt.group_by(*columns)

Expand All @@ -44,7 +44,7 @@ async def insert_owner(self, name: str) -> int:

async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
stmt = insert(Cars).values(
licensePlate=license_plate, model=model, ownerID=owner_id
license_plate=license_plate, model=model, owner_id=owner_id
)

result = await self.conn.execute(stmt)
Expand Down
6 changes: 3 additions & 3 deletions diracx-db/src/diracx/db/sql/dummy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

class Owners(Base):
__tablename__ = "Owners"
ownerID = Column(Integer, primary_key=True, autoincrement=True)
owner_id = Column(Integer, primary_key=True, autoincrement=True)
creation_time = DateNowColumn()
name = Column(String(255))


class Cars(Base):
__tablename__ = "Cars"
licensePlate = Column(Uuid(), primary_key=True)
license_plate = Column(Uuid(), primary_key=True)
model = Column(String(255))
ownerID = Column(Integer, ForeignKey(Owners.ownerID))
owner_id = Column(Integer, ForeignKey(Owners.owner_id))
67 changes: 27 additions & 40 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

if TYPE_CHECKING:
from sqlalchemy.sql.elements import BindParameter
from diracx.core.exceptions import InvalidQueryError, JobNotFound

from diracx.core.exceptions import InvalidQueryError, JobNotFoundError
from diracx.core.models import (
LimitedJobStatusReturn,
SearchSpec,
Expand Down Expand Up @@ -42,7 +43,7 @@ class JobDB(BaseSQLDB):
# TODO: this is copied from the DIRAC JobDB
# but is overwriten in LHCbDIRAC, so we need
# to find a way to make it dynamic
jdl2DBParameters = ["JobName", "JobType", "JobGroup"]
jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]

async def summary(self, group_by, search) -> list[dict[str, str | int]]:
columns = _get_columns(Jobs.__table__, group_by)
Expand Down Expand Up @@ -110,11 +111,11 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
],
)

async def setJobAttributes(self, job_id, jobData):
async def set_job_attributes(self, job_id, job_data):
"""TODO: add myDate and force parameters."""
if "Status" in jobData:
jobData = jobData | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData)
if "Status" in job_data:
job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
stmt = update(Jobs).where(Jobs.JobID == job_id).values(job_data)
await self.conn.execute(stmt)

async def create_job(self, original_jdl):
Expand Down Expand Up @@ -159,9 +160,9 @@ async def update_job_jdls(self, jdls_to_update: dict[int, str]):
],
)

async def checkAndPrepareJob(
async def check_and_prepare_job(
self,
jobID,
job_id,
class_ad_job,
class_ad_req,
owner,
Expand All @@ -178,8 +179,8 @@ async def checkAndPrepareJob(
checkAndPrepareJob,
)

retVal = checkAndPrepareJob(
jobID,
ret_val = checkAndPrepareJob(
job_id,
class_ad_job,
class_ad_req,
owner,
Expand All @@ -188,21 +189,21 @@ async def checkAndPrepareJob(
vo,
)

if not retVal["OK"]:
if cmpError(retVal, EWMSSUBM):
await self.setJobAttributes(jobID, job_attrs)
if not ret_val["OK"]:
if cmpError(ret_val, EWMSSUBM):
await self.set_job_attributes(job_id, job_attrs)

returnValueOrRaise(retVal)
returnValueOrRaise(ret_val)

async def setJobJDL(self, job_id, jdl):
async def set_job_jdl(self, job_id, jdl):
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

stmt = (
update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl))
)
await self.conn.execute(stmt)

async def setJobJDLsBulk(self, jdls):
async def set_job_jdl_bulk(self, jdls):
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

await self.conn.execute(
Expand All @@ -212,19 +213,19 @@ async def setJobJDLsBulk(self, jdls):
[{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()],
)

async def setJobAttributesBulk(self, jobData):
async def set_job_attributes_bulk(self, job_data):
"""TODO: add myDate and force parameters."""
for job_id in jobData.keys():
if "Status" in jobData[job_id]:
jobData[job_id].update(
for job_id in job_data.keys():
if "Status" in job_data[job_id]:
job_data[job_id].update(
{"LastUpdateTime": datetime.now(tz=timezone.utc)}
)
columns = set(key for attrs in jobData.values() for key in attrs.keys())
columns = set(key for attrs in job_data.values() for key in attrs.keys())
case_expressions = {
column: case(
*[
(Jobs.__table__.c.JobID == job_id, attrs[column])
for job_id, attrs in jobData.items()
for job_id, attrs in job_data.items()
if column in attrs
],
else_=getattr(Jobs.__table__.c, column), # Retain original value
Expand All @@ -235,25 +236,11 @@ async def setJobAttributesBulk(self, jobData):
stmt = (
Jobs.__table__.update()
.values(**case_expressions)
.where(Jobs.__table__.c.JobID.in_(jobData.keys()))
.where(Jobs.__table__.c.JobID.in_(job_data.keys()))
)
await self.conn.execute(stmt)

async def getJobJDL(self, job_id: int, original: bool = False) -> str:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

if original:
stmt = select(JobJDLs.OriginalJDL).where(JobJDLs.JobID == job_id)
else:
stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id)

jdl = (await self.conn.execute(stmt)).scalar_one()
if jdl:
jdl = extractJDL(jdl)

return jdl

async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]:
async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int | str, str]:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

if original:
Expand All @@ -278,7 +265,7 @@ async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
**dict((await self.conn.execute(stmt)).one()._mapping)
)
except NoResultFound as e:
raise JobNotFound(job_id) from e
raise JobNotFoundError(job_id) from e

async def set_job_command(self, job_id: int, command: str, arguments: str = ""):
"""Store a command to be passed to the job together with the next heart beat."""
Expand All @@ -291,7 +278,7 @@ async def set_job_command(self, job_id: int, command: str, arguments: str = ""):
)
await self.conn.execute(stmt)
except IntegrityError as e:
raise JobNotFound(job_id) from e
raise JobNotFoundError(job_id) from e

async def set_job_command_bulk(self, commands):
"""Store a command to be passed to the job together with the next heart beat."""
Expand Down
4 changes: 2 additions & 2 deletions diracx-db/src/diracx/db/sql/job_logging/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections import defaultdict

from diracx.core.exceptions import JobNotFound
from diracx.core.exceptions import JobNotFoundError
from diracx.core.models import (
JobStatus,
JobStatusReturn,
Expand Down Expand Up @@ -212,7 +212,7 @@ async def get_wms_time_stamps(self, job_id):
).where(LoggingInfo.JobID == job_id)
rows = await self.conn.execute(stmt)
if not rows.rowcount:
raise JobNotFound(job_id) from None
raise JobNotFoundError(job_id) from None

for event, etime in rows:
result[event] = str(etime + MAGIC_EPOC_NUMBER)
Expand Down
Loading

0 comments on commit a7b2f36

Please sign in to comment.