Skip to content

Commit

Permalink
[WIP] make tests pass with resource span
Browse files Browse the repository at this point in the history
  • Loading branch information
adfaure committed Dec 15, 2023
1 parent 930be4a commit 28a6483
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 57 deletions.
3 changes: 2 additions & 1 deletion oar/cli/oarnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def get_resources_for_job(session):
res = (
session.query(Resource, Job)
.filter(Job.assigned_moldable_job == AssignedResource.moldable_id)
.filter(AssignedResource.resource_id == Resource.id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(Job.state == "Running")
.order_by(Resource.id)
.all()
Expand Down
16 changes: 12 additions & 4 deletions oar/lib/basequery.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ def get_assigned_jobs_resources(self, jobs):
AssignedResource,
Job.assigned_moldable_job == AssignedResource.moldable_id,
)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.filter(Job.id.in_([job.id for job in jobs]))
.order_by(Job.id.asc())
)
Expand All @@ -230,7 +232,10 @@ def get_assigned_one_job_resources(self, job):
AssignedResource,
Job.assigned_moldable_job == AssignedResource.moldable_id,
)
.join(Resource, Resource.id == AssignedResource.resource_id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
# .filter(job_id_column == job.id)
)
return query
Expand All @@ -255,7 +260,10 @@ def get_jobs_resource(self, resource_id):
db = self.session
query = (
db.query(Job)
.filter(AssignedResource.resource_id == resource_id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.filter(MoldableJobDescription.id == AssignedResource.moldable_id)
.filter(MoldableJobDescription.job_id == Job.id)
)
Expand Down
21 changes: 15 additions & 6 deletions oar/lib/job_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import random
import re
from oar.lib.resource import ResourceSet

from procset import ProcSet
from sqlalchemy import distinct, func, insert, text
Expand Down Expand Up @@ -36,6 +35,7 @@
WalltimeChange,
)
from oar.lib.plugins import find_plugin_function
from oar.lib.resource import ResourceSet
from oar.lib.resource_handling import (
get_current_resources_with_suspended_job,
update_current_scheduler_priority,
Expand Down Expand Up @@ -841,7 +841,7 @@ def add_resource_job_pairs(session, moldable_id):
{
"moldable_job_id": res_mld_id.moldable_id,
"resource_id": res_mld_id.resource_id,
"span": res_mld_id.span
"span": res_mld_id.span,
}
for res_mld_id in resources_mld_ids
]
Expand Down Expand Up @@ -1209,9 +1209,12 @@ def update_scheduler_last_job_date(session, date, moldable_id):
else:
session.query(Resource).filter(
AssignedResource.moldable_id == moldable_id
).filter(Resource.id == AssignedResource.resource_id).update(
).filter(Resource.id >= AssignedResource.resource_id).filter(
Resource.id < AssignedResource.resource_id + AssignedResource.span
).update(
{Resource.last_job_date: date}, synchronize_session=False
)

session.commit()


Expand Down Expand Up @@ -1445,7 +1448,8 @@ def get_cpuset_values(session, config, cpuset_field, moldable_id):
results = (
session.query(Resource.network_address, getattr(Resource, cpuset_field))
.filter(AssignedResource.moldable_id == moldable_id)
.filter(AssignedResource.resource_id == Resource.id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(Resource.network_address != "")
.filter(text(sql_where_string))
.group_by(Resource.network_address, getattr(Resource, cpuset_field))
Expand Down Expand Up @@ -1605,7 +1609,8 @@ def get_job_current_hostnames(session, job_id):
session.query(distinct(Resource.network_address))
.filter(AssignedResource.index == "CURRENT")
.filter(MoldableJobDescription.index == "CURRENT")
.filter(AssignedResource.resource_id == Resource.id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(MoldableJobDescription.id == AssignedResource.moldable_id)
.filter(MoldableJobDescription.job_id == job_id)
.filter(Resource.network_address != "")
Expand Down Expand Up @@ -1980,7 +1985,8 @@ def get_job_host_log(session, moldable_id):
res = (
session.query(distinct(Resource.network_address))
.filter(AssignedResource.moldable_id == moldable_id)
.filter(Resource.id == AssignedResource.resource_id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(Resource.network_address != "")
.filter(Resource.type == "default")
.all()
Expand Down Expand Up @@ -2623,6 +2629,7 @@ def get_timer_armed_job(

def archive_some_moldable_job_nodes(session, config, moldable_id, hosts):
"""Sets the index fields to LOG in the table assigned_resources"""
# TODO
if config["DB_TYPE"] == "Pg":
session.query(AssignedResource).filter(
AssignedResource.moldable_id == moldable_id
Expand All @@ -2641,6 +2648,8 @@ def get_job_resources_properties(session, job_id):
.filter(Job.id == job_id)
.filter(Job.assigned_moldable_job == AssignedResource.moldable_id)
.filter(AssignedResource.resource_id == Resource.id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.order_by(Resource.id)
.all()
)
Expand Down
72 changes: 38 additions & 34 deletions oar/lib/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,34 +119,35 @@ def get_gantt_hostname_to_wake_up(session: Session, date: int, wakeup_time: int)
hosts = [h_tpl[0] for h_tpl in hostnames]
return hosts

def get_gantt_hostname_to_wake_up_(session, date, wakeup_time):
"""Get hostname that we must wake up to launch jobs"""
hostnames = (
session.query(Resource.network_address)
.filter(GanttJobsResource.moldable_id == GanttJobsPrediction.moldable_id)
.filter(MoldableJobDescription.id == GanttJobsPrediction.moldable_id)
.filter(Job.id == MoldableJobDescription.job_id)
.filter(GanttJobsPrediction.start_time <= date + wakeup_time)
.filter(Job.state == "Waiting")
.filter(Resource.id == GanttJobsResource.resource_id)
.filter(Resource.state == "Absent")
.filter(Resource.network_address != "")
.filter(Resource.type == "default")
.filter(
(GanttJobsPrediction.start_time + MoldableJobDescription.walltime)
<= Resource.available_upto
)
.group_by(Resource.network_address)
.all()
)
hosts = [h_tpl[0] for h_tpl in hostnames]
return hosts

# TODO fail merge
# def get_gantt_hostname_to_wake_up(session, date, wakeup_time):
# """Get hostname that we must wake up to launch jobs"""
# # get save assignement
# hostnames = (
# session.query(Resource.network_address)
# .filter(GanttJobsResource.moldable_id == GanttJobsPrediction.moldable_id)
# .filter(MoldableJobDescription.id == GanttJobsPrediction.moldable_id)
# .filter(Job.id == MoldableJobDescription.job_id)
# .filter(GanttJobsPrediction.start_time <= date + wakeup_time)
# .filter(Job.state == "Waiting")
# .filter(Resource.id == GanttJobsResource.resource_id)
# .filter(Resource.state == "Absent")
# .filter(Resource.network_address != "")
# .filter(Resource.type == "default")
# .filter(
# (GanttJobsPrediction.start_time + MoldableJobDescription.walltime)
# <= Resource.available_upto
# )
# .group_by(Resource.network_address)
# .all()
# )
# hosts = [h_tpl[0] for h_tpl in hostnames]
# return hosts

def get_gantt_hostname_to_wake_up(session, date, wakeup_time):
"""Get hostname that we must wake up to launch jobs"""
# get save assignement

def get_gantt_hostname_to_wake_up_(session, date, wakeup_time):
"""Get hostname that we must wake up to launch jobs"""
hostnames = (
session.query(Resource.network_address)
.filter(GanttJobsResource.moldable_id == GanttJobsPrediction.moldable_id)
Expand All @@ -168,12 +169,8 @@ def get_gantt_hostname_to_wake_up(session, date, wakeup_time):
hosts = [h_tpl[0] for h_tpl in hostnames]
return hosts

<<<<<<< HEAD
def get_next_job_date_on_node(session: Session, hostname: str):
=======

def get_next_job_date_on_node(session, hostname):
>>>>>>> ec4caec ([test] clean after rebase)
def get_next_job_date_on_node(session: Session, hostname: str):
result = (
session.query(func.min(GanttJobsPrediction.start_time))
.filter(Resource.network_address == hostname)
Expand Down Expand Up @@ -203,8 +200,11 @@ def get_alive_nodes_with_jobs(
"""Returns the list of occupied nodes"""
result = (
session.query(distinct(Resource.network_address))
.filter(Resource.id == AssignedResource.resource_id)
.filter(AssignedResource.moldable_id == MoldableJobDescription.id)
# .filter(AssignedResource.moldable_id == MoldableJobDescription.id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.filter(MoldableJobDescription.job_id == Job.id)
.filter(
Job.state.in_(
Expand Down Expand Up @@ -398,7 +398,8 @@ def get_current_assigned_nodes(
results = (
session.query(distinct(Resource.network_address))
.filter(AssignedResource.index == "CURRENT")
.filter(Resource.id == AssignedResource.resource_id)
.filter(Resource.id >= AssignedResource.resource_id)
.filter(Resource.id < AssignedResource.resource_id + AssignedResource.span)
.filter(Resource.type == "default")
.all()
)
Expand Down Expand Up @@ -428,7 +429,10 @@ def get_node_job_to_frag(session: Session, hostname: str) -> List[int]:
.filter(AssignedResource.index == "CURRENT")
.filter(MoldableJobDescription.index == "CURRENT")
.filter(Resource.network_address == hostname)
.filter(AssignedResource.resource_id == Resource.id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.filter(AssignedResource.moldable_id == MoldableJobDescription.id)
.filter(MoldableJobDescription.job_id == Job.id)
.filter(Job.state != "Terminated")
Expand Down
36 changes: 26 additions & 10 deletions oar/lib/resource_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ def remove_resource(session, resource_id, user=None):
if state == "Dead":
results = (
session.query(Job.id, Job.assigned_moldable_job)
.filter(AssignedResource.resource_id == resource_id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.filter(AssignedResource.moldable_id == Job.assigned_moldable_job)
.all()
)
Expand Down Expand Up @@ -256,7 +259,10 @@ def get_current_assigned_job_resources(session, moldable_id):
session.query(Resource)
.filter(AssignedResource.index == "CURRENT")
.filter(AssignedResource.moldable_id == moldable_id)
.filter(Resource.id == AssignedResource.resource_id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
)
.all()
)
return res
Expand Down Expand Up @@ -314,7 +320,8 @@ def update_resource_nextFinaudDecision(session, resource_id, finaud_decision):

def update_scheduler_last_job_date(session, date, moldable_id):
session.query(Resource).filter(AssignedResource.moldable_id == moldable_id).filter(
AssignedResource.resource_id == Resource.resource_id
Resource.id >= AssignedResource.resource_id,
Resource.id < AssignedResource.resource_id + AssignedResource.span,
).update({Resource.last_job_date: date}, synchronize_session=False)


Expand Down Expand Up @@ -380,7 +387,11 @@ def update_current_scheduler_priority(session, config, job, value, state):
session.query(distinct(getattr(Resource, f)))
.filter(AssignedResource.index == "CURRENT")
.filter(AssignedResource.moldable_id == job.assigned_moldable_job)
.filter(AssignedResource.resource_id == Resource.id)
.filter(
Resource.id >= AssignedResource.resource_id,
Resource.id
< AssignedResource.resource_id + AssignedResource.span,
)
.all()
)

Expand Down Expand Up @@ -516,12 +527,17 @@ def get_count_busy_resources(
active_moldable_job_ids = session.query(Job.assigned_moldable_job).filter(
Job.state.in_(("toLaunch", "Running", "Resuming"))
)
count_busy_resources = (
session.query(func.count(distinct(AssignedResource.resource_id)))
.filter(AssignedResource.moldable_id.in_(active_moldable_job_ids))
.scalar()
)
return count_busy_resources
count_busy_resources: List[AssignedResource] = (
session.query(AssignedResource).filter(
AssignedResource.moldable_id.in_(active_moldable_job_ids)
)
).all()

total = 0
for resource in count_busy_resources:
total += resource.span

return total


def resources_creation(
Expand Down
9 changes: 7 additions & 2 deletions tests/api/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from oar.api.url_utils import replace_query_params
from oar.kao.meta_sched import meta_schedule
from oar.lib.job_handling import insert_job, set_job_state
from oar.lib.models import FragJob, Job
from oar.lib.models import AssignedResource, FragJob, Job


def test_jobs_index(client, minimal_db_initialization):
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_app_jobs_get_one_details(client, minimal_db_initialization, setup_confi
"""GET /jobs/<id>?details=true"""
config, db = setup_config
job_id = insert_job(
minimal_db_initialization, res=[(60, [("resource_id=8", "")])], properties=""
minimal_db_initialization, res=[(60, [("resource_id=4", "")])], properties=""
)
meta_schedule(minimal_db_initialization, config, "internal")
res = client.get("/jobs/{}?details=true".format(job_id))
Expand All @@ -95,6 +95,11 @@ def test_app_jobs_get_resources(client, minimal_db_initialization, setup_config)
)
meta_schedule(minimal_db_initialization, config, "internal")
res = client.get("/jobs/{}/resources".format(job_id))

for ar in minimal_db_initialization.query(AssignedResource).all():
print(vars(ar))

print(res.json())
assert len(res.json()["items"]) == 4


Expand Down

0 comments on commit 28a6483

Please sign in to comment.