Skip to content

Commit

Permalink
feat(consistency): make SQL Alchemy interfaces consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Dec 20, 2024
1 parent a7b2f36 commit 7f6f884
Show file tree
Hide file tree
Showing 23 changed files with 355 additions and 317 deletions.
18 changes: 9 additions & 9 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ async def get_device_flow(self, device_code: str, max_validity: int):
stmt = select(
DeviceFlows,
(DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
"is_expired"
"IsExpired"
),
).with_for_update()
stmt = stmt.where(
DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
)
res = dict((await self.conn.execute(stmt)).one()._mapping)

if res["is_expired"]:
if res["IsExpired"]:
raise ExpiredFlowError()

if res["status"] == FlowStatus.READY:
if res["Status"] == FlowStatus.READY:
# Update the status to Done before returning
await self.conn.execute(
update(DeviceFlows)
Expand All @@ -81,10 +81,10 @@ async def get_device_flow(self, device_code: str, max_validity: int):
)
return res

if res["status"] == FlowStatus.DONE:
if res["Status"] == FlowStatus.DONE:
raise AuthorizationError("Code was already used")

if res["status"] == FlowStatus.PENDING:
if res["Status"] == FlowStatus.PENDING:
raise PendingAuthorizationError()

raise AuthorizationError("Bad state in device flow")
Expand Down Expand Up @@ -190,7 +190,7 @@ async def authorization_flow_insert_id_token(
stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
stmt = stmt.where(AuthorizationFlows.uuid == uuid)
row = (await self.conn.execute(stmt)).one()
return code, row.redirect_uri
return code, row.RedirectURI

async def get_authorization_flow(self, code: str, max_validity: int):
hashed_code = hashlib.sha256(code.encode()).hexdigest()
Expand All @@ -205,7 +205,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):

res = dict((await self.conn.execute(stmt)).one()._mapping)

if res["status"] == FlowStatus.READY:
if res["Status"] == FlowStatus.READY:
# Update the status to Done before returning
await self.conn.execute(
update(AuthorizationFlows)
Expand All @@ -215,7 +215,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):

return res

if res["status"] == FlowStatus.DONE:
if res["Status"] == FlowStatus.DONE:
raise AuthorizationError("Code was already used")

raise AuthorizationError("Bad state in authorization flow")
Expand Down Expand Up @@ -247,7 +247,7 @@ async def insert_refresh_token(
row = (await self.conn.execute(stmt)).one()

# Return the JWT ID and the creation time
return jti, row.creation_time
return jti, row.CreationTime

async def get_refresh_token(self, jti: str) -> dict:
"""Get refresh token details bound to a given JWT ID."""
Expand Down
46 changes: 23 additions & 23 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,27 @@ class FlowStatus(Enum):

class DeviceFlows(Base):
__tablename__ = "DeviceFlows"
user_code = Column(String(USER_CODE_LENGTH), primary_key=True)
status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
creation_time = DateNowColumn()
client_id = Column(String(255))
scope = Column(String(1024))
device_code = Column(String(128), unique=True) # Should be a hash
id_token = NullColumn(JSON())
user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
creation_time = DateNowColumn("CreationTime")
client_id = Column("ClientID", String(255))
scope = Column("Scope", String(1024))
device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash
id_token = NullColumn("IDToken", JSON())


class AuthorizationFlows(Base):
__tablename__ = "AuthorizationFlows"
uuid = Column(Uuid(as_uuid=False), primary_key=True)
status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
client_id = Column(String(255))
creation_time = DateNowColumn()
scope = Column(String(1024))
code_challenge = Column(String(255))
code_challenge_method = Column(String(8))
redirect_uri = Column(String(255))
code = NullColumn(String(255)) # Should be a hash
id_token = NullColumn(JSON())
uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
client_id = Column("ClientID", String(255))
creation_time = DateNowColumn("CretionTime")
scope = Column("Scope", String(1024))
code_challenge = Column("CodeChallenge", String(255))
code_challenge_method = Column("CodeChallengeMethod", String(8))
redirect_uri = Column("RedirectURI", String(255))
code = NullColumn("Code", String(255)) # Should be a hash
id_token = NullColumn("IDToken", JSON())


class RefreshTokenStatus(Enum):
Expand All @@ -85,13 +85,13 @@ class RefreshTokens(Base):

__tablename__ = "RefreshTokens"
# Refresh token attributes
jti = Column(Uuid(as_uuid=False), primary_key=True)
jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
status = EnumColumn(
RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
"Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
)
creation_time = DateNowColumn()
scope = Column(String(1024))
creation_time = DateNowColumn("CreationTime")
scope = Column("Scope", String(1024))

# User attributes bound to the refresh token
sub = Column(String(1024))
preferred_username = Column(String(255))
sub = Column("Sub", String(1024))
preferred_username = Column("PreferredUsername", String(255))
12 changes: 6 additions & 6 deletions diracx-db/src/diracx/db/sql/dummy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

class Owners(Base):
__tablename__ = "Owners"
owner_id = Column(Integer, primary_key=True, autoincrement=True)
creation_time = DateNowColumn()
name = Column(String(255))
owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
creation_time = DateNowColumn("CreationTime")
name = Column("Name", String(255))


class Cars(Base):
__tablename__ = "Cars"
license_plate = Column(Uuid(), primary_key=True)
model = Column(String(255))
owner_id = Column(Integer, ForeignKey(Owners.owner_id))
license_plate = Column("LicensePlate", Uuid(), primary_key=True)
model = Column("Model", String(255))
owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))
28 changes: 16 additions & 12 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class JobDB(BaseSQLDB):
async def summary(self, group_by, search) -> list[dict[str, str | int]]:
columns = _get_columns(Jobs.__table__, group_by)

stmt = select(*columns, func.count(Jobs.JobID).label("count"))
stmt = select(*columns, func.count(Jobs.job_id).label("count"))
stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
stmt = stmt.group_by(*columns)

Expand Down Expand Up @@ -115,7 +115,7 @@ async def set_job_attributes(self, job_id, job_data):
"""TODO: add myDate and force parameters."""
if "Status" in job_data:
job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
stmt = update(Jobs).where(Jobs.JobID == job_id).values(job_data)
stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data)
await self.conn.execute(stmt)

async def create_job(self, original_jdl):
Expand Down Expand Up @@ -199,7 +199,7 @@ async def set_job_jdl(self, job_id, jdl):
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

stmt = (
update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl))
update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl))
)
await self.conn.execute(stmt)

Expand Down Expand Up @@ -240,15 +240,19 @@ async def set_job_attributes_bulk(self, job_data):
)
await self.conn.execute(stmt)

async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int | str, str]:
async def get_job_jdls(
self, job_ids, original: bool = False
) -> dict[int | str, str]:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

if original:
stmt = select(JobJDLs.JobID, JobJDLs.OriginalJDL).where(
JobJDLs.JobID.in_(job_ids)
stmt = select(JobJDLs.job_id, JobJDLs.original_jdl).where(
JobJDLs.job_id.in_(job_ids)
)
else:
stmt = select(JobJDLs.JobID, JobJDLs.JDL).where(JobJDLs.JobID.in_(job_ids))
stmt = select(JobJDLs.job_id, JobJDLs.jdl).where(
JobJDLs.job_id.in_(job_ids)
)

return {
jobid: extractJDL(jdl)
Expand All @@ -258,9 +262,9 @@ async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int | str,

async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
try:
stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where(
Jobs.JobID == job_id
)
stmt = select(
Jobs.status, Jobs.minor_status, Jobs.application_status
).where(Jobs.job_id == job_id)
return LimitedJobStatusReturn(
**dict((await self.conn.execute(stmt)).one()._mapping)
)
Expand Down Expand Up @@ -298,7 +302,7 @@ async def set_job_command_bulk(self, commands):

async def delete_jobs(self, job_ids: list[int]):
"""Delete jobs from the database."""
stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids))
stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids))
await self.conn.execute(stmt)

async def set_properties(
Expand Down Expand Up @@ -331,7 +335,7 @@ async def set_properties(
if update_timestamp:
values["LastUpdateTime"] = datetime.now(tz=timezone.utc)

stmt = update(Jobs).where(Jobs.JobID == bindparam("job_id")).values(**values)
stmt = update(Jobs).where(Jobs.job_id == bindparam("job_id")).values(**values)
rows = await self.conn.execute(stmt, update_parameters)

return rows.rowcount
108 changes: 54 additions & 54 deletions diracx-db/src/diracx/db/sql/job/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,34 @@
class Jobs(JobDBBase):
__tablename__ = "Jobs"

JobID = Column(
job_id = Column(
"JobID",
Integer,
ForeignKey("JobJDLs.JobID", ondelete="CASCADE"),
primary_key=True,
default=0,
)
JobType = Column("JobType", String(32), default="user")
JobGroup = Column("JobGroup", String(32), default="00000000")
Site = Column("Site", String(100), default="ANY")
JobName = Column("JobName", String(128), default="Unknown")
Owner = Column("Owner", String(64), default="Unknown")
OwnerGroup = Column("OwnerGroup", String(128), default="Unknown")
VO = Column("VO", String(32))
SubmissionTime = NullColumn("SubmissionTime", DateTime)
RescheduleTime = NullColumn("RescheduleTime", DateTime)
LastUpdateTime = NullColumn("LastUpdateTime", DateTime)
StartExecTime = NullColumn("StartExecTime", DateTime)
HeartBeatTime = NullColumn("HeartBeatTime", DateTime)
EndExecTime = NullColumn("EndExecTime", DateTime)
Status = Column("Status", String(32), default="Received")
MinorStatus = Column("MinorStatus", String(128), default="Unknown")
ApplicationStatus = Column("ApplicationStatus", String(255), default="Unknown")
UserPriority = Column("UserPriority", Integer, default=0)
RescheduleCounter = Column("RescheduleCounter", Integer, default=0)
VerifiedFlag = Column("VerifiedFlag", EnumBackedBool(), default=False)
job_type = Column("JobType", String(32), default="user")
job_group = Column("JobGroup", String(32), default="00000000")
site = Column("Site", String(100), default="ANY")
job_name = Column("JobName", String(128), default="Unknown")
owner = Column("Owner", String(64), default="Unknown")
owner_group = Column("OwnerGroup", String(128), default="Unknown")
vo = Column("VO", String(32))
submission_time = NullColumn("SubmissionTime", DateTime)
reschedule_time = NullColumn("RescheduleTime", DateTime)
last_update_time = NullColumn("LastUpdateTime", DateTime)
start_exec_time = NullColumn("StartExecTime", DateTime)
heart_beat_time = NullColumn("HeartBeatTime", DateTime)
end_exec_time = NullColumn("EndExecTime", DateTime)
status = Column("Status", String(32), default="Received")
minor_status = Column("MinorStatus", String(128), default="Unknown")
application_status = Column("ApplicationStatus", String(255), default="Unknown")
user_priority = Column("UserPriority", Integer, default=0)
reschedule_counter = Column("RescheduleCounter", Integer, default=0)
verified_flag = Column("VerifiedFlag", EnumBackedBool(), default=False)
# TODO: Should this be True/False/"Failed"? Or True/False/Null?
AccountedFlag = Column(
accounted_flag = Column(
"AccountedFlag", Enum("True", "False", "Failed"), default="False"
)

Expand All @@ -64,66 +64,66 @@ class Jobs(JobDBBase):

class JobJDLs(JobDBBase):
__tablename__ = "JobJDLs"
JobID = Column(Integer, autoincrement=True, primary_key=True)
JDL = Column(Text)
JobRequirements = Column(Text)
OriginalJDL = Column(Text)
job_id = Column("JobID", Integer, autoincrement=True, primary_key=True)
jdl = Column("JDL", Text)
job_requirements = Column("JobRequirements", Text)
original_jdl = Column("OriginalJDL", Text)


class InputData(JobDBBase):
__tablename__ = "InputData"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
LFN = Column(String(255), default="", primary_key=True)
Status = Column(String(32), default="AprioriGood")
lfn = Column("LFN", String(255), default="", primary_key=True)
status = Column("Status", String(32), default="AprioriGood")


class JobParameters(JobDBBase):
__tablename__ = "JobParameters"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
Name = Column(String(100), primary_key=True)
Value = Column(Text)
name = Column("Name", String(100), primary_key=True)
value = Column("Value", Text)


class OptimizerParameters(JobDBBase):
__tablename__ = "OptimizerParameters"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
Name = Column(String(100), primary_key=True)
Value = Column(Text)
name = Column("Name", String(100), primary_key=True)
value = Column("Value", Text)


class AtticJobParameters(JobDBBase):
__tablename__ = "AtticJobParameters"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
Name = Column(String(100), primary_key=True)
Value = Column(Text)
RescheduleCycle = Column(Integer)
name = Column("Name", String(100), primary_key=True)
value = Column("Value", Text)
reschedule_cycle = Column("RescheduleCycle", Integer)


class HeartBeatLoggingInfo(JobDBBase):
__tablename__ = "HeartBeatLoggingInfo"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
Name = Column(String(100), primary_key=True)
Value = Column(Text)
HeartBeatTime = Column(DateTime, primary_key=True)
name = Column("Name", String(100), primary_key=True)
value = Column("Value", Text)
heart_beat_time = Column("HeartBeatTime", DateTime, primary_key=True)


class JobCommands(JobDBBase):
__tablename__ = "JobCommands"
JobID = Column(
Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
job_id = Column(
"JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
)
Command = Column(String(100))
Arguments = Column(String(100))
Status = Column(String(64), default="Received")
ReceptionTime = Column(DateTime, primary_key=True)
ExecutionTime = NullColumn(DateTime)
command = Column("Command", String(100))
arguments = Column("Arguments", String(100))
status = Column("Status", String(64), default="Received")
reception_time = Column("ReceptionTime", DateTime, primary_key=True)
execution_time = NullColumn("ExecutionTime", DateTime)
Loading

0 comments on commit 7f6f884

Please sign in to comment.