Skip to content

Commit

Permalink
multi-pilot changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pmanthasf committed Jul 16, 2024
1 parent 24edb61 commit 211ef98
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 99 deletions.
66 changes: 66 additions & 0 deletions examples/pq_multi_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os

import pennylane as qml
from pilot.pilot_compute_service import ExecutionEngine, PilotComputeService
from time import sleep

RESOURCE_URL = "ssh://localhost"
WORKING_DIRECTORY = os.path.join(os.environ["HOME"], "work")

pilot_compute_description = {
"resource": RESOURCE_URL,
"number_of_nodes": 2,
"cores_per_node": 10,
}


def start_pcs():
pcs = PilotComputeService(ExecutionEngine.DASK, WORKING_DIRECTORY)
for i in range(2):
pilot_compute_description["name"] = f"pilot-{i}"
pcs.create_pilot(pilot_compute_description=pilot_compute_description)
return pcs

def pennylane_quantum_circuit():
wires = 4
layers = 1
dev = qml.device('default.qubit', wires=wires, shots=None)

@qml.qnode(dev)
def circuit(parameters):
qml.StronglyEntanglingLayers(weights=parameters, wires=range(wires))
return [qml.expval(qml.PauliZ(i)) for i in range(wires)]

shape = qml.StronglyEntanglingLayers.shape(n_layers=layers, n_wires=wires)
weights = qml.numpy.random.random(size=shape)
return circuit(weights)


if __name__ == "__main__":
pcs = None
try:
# Start Pilots
pcs = start_pcs()

pilots = pcs.get_pilots()

# print pilot names
for pname in pilots:
print(pname)

# Submit tasks to pcs
tasks = []
for i in range(1000):
k = pcs.submit_task(f"task_pennylane-{i}", pennylane_quantum_circuit)
tasks.append(k)

for i in range(1000):
k = pcs.submit_task(f"task_pennylane-{i}", pennylane_quantum_circuit, pilot=pilots[0])
tasks.append(k)


pcs.wait_tasks(tasks)
finally:
if pcs:
pcs.cancel()

2 changes: 1 addition & 1 deletion pilot/job/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, job_description, resource_url, pilot_compute_description):
if urlparse(resource_url).username is not None:
self.user = urlparse(resource_url).username
self.logger.debug("URL: " + str(self.resource_url) + " Host: " + self.host)
self.id = "pilot-quantum-ssh" + str(uuid.uuid1())
self.id = "pilot-quantum-ssh-" + str(uuid.uuid1())
self.job_id = self.id
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
self.job_output = open(
Expand Down
181 changes: 111 additions & 70 deletions pilot/pilot_compute_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import os
from dask.distributed import wait
from datetime import datetime
from enum import Enum

ExecutionEngine = Enum('ExecutionEngine', ["RAY", "DASK"])




class PilotComputeService:
Expand All @@ -20,14 +25,39 @@ class PilotComputeService:
in the P* Model.
"""

def __init__(self):

def __init__(self, execution_engine=ExecutionEngine.DASK, working_directory="/tmp"):
"""Create a PilotComputeService object.
Args:
pjs_id (optional): Connect to an existing PilotComputeService.
"""

if not execution_engine in ExecutionEngine:
raise PilotAPIException(f"Invalid Execution Engine {execution_engine}")


self.execution_engine = execution_engine
self.pcs_working_directory = f"{working_directory}/pcs-{uuid.uuid4()}"
os.makedirs(self.pcs_working_directory)

self.cluster_manager = self.__get_cluster_manager(execution_engine, self.pcs_working_directory)
self.cluster_manager.start_scheduler()

self.logger = PilotComputeServiceLogger()
self.logger.info("PilotComputeService initialized.")
self.logger.info("PilotComputeService initialized.")
self.pilots = {}
self.client = None


def get_pilots(self):
return list(self.pilots.keys())

def get_pilot(self, name):
if name not in self.pilots:
raise PilotAPIException(f"Pilot {name} not found")

return self.pilots[name]

def cancel(self):
"""Cancel the PilotComputeService.
Expand All @@ -38,96 +68,84 @@ def cancel(self):
Result of the operation.
"""
self.logger.info("Cancelling PilotComputeService.")
pass
self.cluster_manager.cancel()


def create_pilot(self, pilot_compute_description):
"""
Create and initialize a PilotCompute instance based on the provided description.
:param pilot_compute_description: Dictionary containing details about the cluster to launch.
:return: Initialized PilotCompute instance.
"""
self.logger.info("Creating a new PilotCompute.")
working_directory = pilot_compute_description.get("working_directory", "/tmp")
"""
# self.validate_pcd(pilot_compute_description)

pilot_name = pilot_compute_description.get("name", f"pilot-{uuid.uuid4()}")

framework_type = pilot_compute_description.get("type")
if framework_type is None:
self.logger.error("Invalid Pilot Compute Description: type not specified")
raise PilotAPIException("Invalid Pilot Compute Description: type not specified")
worker_cluster_manager = self.__get_cluster_manager(self.execution_engine, self.pcs_working_directory)

manager = self.__get_cluster_manager(framework_type, working_directory)
self.logger.info(f"Create Pilot with description {pilot_compute_description}")
pilot_compute_description["working_directory"] = self.pcs_working_directory

batch_job = manager.submit_job(pilot_compute_description)
batch_job = worker_cluster_manager.submit_job(pilot_compute_description)
self.pilot_id = batch_job.get_id()

self.metrics_file_name = os.path.join(working_directory, f"{self.pilot_id}-metrics.csv")
self.metrics_file_name = os.path.join(self.pcs_working_directory, f"{self.pilot_id}-metrics.csv")


details = manager.get_config_data()
details = worker_cluster_manager.get_config_data()
self.logger.info(f"Cluster details: {details}")
pilot = PilotCompute(self.metrics_file_name, batch_job, cluster_manager=manager)
pilot = PilotCompute(batch_job, cluster_manager=worker_cluster_manager)


self.pilots[pilot_name] = pilot
return pilot

def validate_pcd(self, pilot_compute_description):
if "name" not in pilot_compute_description:
self.logger.error("Invalid Pilot Compute Description: name not specified")
raise PilotAPIException("Invalid Pilot Compute Description: name not specified")

if "type" not in pilot_compute_description:
self.logger.error("Invalid Pilot Compute Description: type not specified")
raise PilotAPIException("Invalid Pilot Compute Description: type not specified")




def __get_cluster_manager(self, framework_type, working_directory):
def __get_cluster_manager(self, execution_engine, working_directory):
"""
Get the appropriate ClusterManager based on the framework type.
:param framework_type: Type of the computing framework.
:param working_directory: Working directory for the cluster.
:return: ClusterManager instance.
"""
if framework_type.startswith("dask"):
return dask_cluster_manager.Manager() # Replace with appropriate manager
elif framework_type == "ray":
if execution_engine==ExecutionEngine.DASK:
return dask_cluster_manager.Manager(working_directory) # Replace with appropriate manager
elif execution_engine == "ray":
job_id = f"ray-{uuid.uuid1()}"
return ray_cluster_manager.Manager(job_id, working_directory) # Replace with appropriate manager

self.logger.error(f"Invalid Pilot Compute Description: invalid type: {framework_type}")
raise PilotAPIException(f"Invalid Pilot Compute Description: invalid type: {framework_type}")
self.logger.error(f"Invalid Pilot Compute Description: invalid type: {execution_engine}")
raise PilotAPIException(f"Invalid Pilot Compute Description: invalid type: {execution_engine}")

def task(self, func):
def get_client(self):
"""
Submit task to PilotComputeService, which can be scheduled on a group of pilots.
Returns the native client for interacting with the task execution engine (i.e. Dask or Ray) started via the Pilot-Job.
see also get_context()
"""
def wrapper(*args, **kwargs):
pass

return wrapper




class PilotAPIException(Exception):
pass


class PilotCompute(object):
"""PilotCompute (PC) representation.
This class is returned by the PilotComputeService when a new PilotCompute
(aka Pilot-Job) is created based on a PilotComputeDescription.
The PilotCompute object can be used by the application to keep track
of active PilotComputes. It has state, can be queried, can be cancelled,
and be re-initialized.
"""

def __init__(self, metrics_file_name, batch_job=None, cluster_manager=None):
self.batch_job = batch_job
self.cluster_manager = cluster_manager
self.client = None
self.metrics_fn = metrics_file_name
return self.cluster_manager.get_client()

def submit_task(self, task_name, func, *args, **kwargs):
pilot_scheduled = 'ANY'

def cancel(self):
if self.client:
self.client.close()
if self.batch_job:
self.batch_job.cancel()
if kwargs.get("pilot"):
if kwargs["pilot"] not in self.pilots:
raise PilotAPIException(f"Pilot {kwargs['pilot']} not found")
pilot_scheduled = kwargs["pilot"]
del kwargs["pilot"]

def submit_task(self, task_name, func, *args, **kwargs):
if not self.client:
self.client = self.get_client()

Expand All @@ -139,6 +157,7 @@ def submit_task(self, task_name, func, *args, **kwargs):

metrics = {
'task_id': task_name,
'pilot_scheduled': pilot_scheduled,
'submit_time': datetime.now(),
'wait_time_secs': None,
'completion_time': None,
Expand All @@ -163,7 +182,7 @@ def task_func(metrics_fn, *args, **kwargs):
metrics["execution_ms"] = time.time() - task_execution_start_time

with open(metrics_fn, 'a', newline='') as csvfile:
fieldnames = ['task_id', 'submit_time', 'wait_time_secs', 'completion_time', 'execution_ms', 'status', 'error_msg']
fieldnames = ['task_id','pilot_scheduled','submit_time', 'wait_time_secs', 'completion_time', 'execution_ms', 'status', 'error_msg']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

if csvfile.tell() == 0:
Expand All @@ -172,8 +191,11 @@ def task_func(metrics_fn, *args, **kwargs):
writer.writerow(metrics)

return result

task_future = self.client.submit(task_func, self.metrics_fn, *args, **kwargs)

if pilot_scheduled != 'ANY':
task_future = self.client.submit(task_func, self.metrics_file_name, *args, **kwargs, workers=pilot_scheduled)
else:
task_future = self.client.submit(task_func, self.metrics_file_name, *args, **kwargs)

return task_future

Expand All @@ -193,6 +215,35 @@ def run(self, func, *args, **kwargs):
print(f"Running qtask with args {args}, kwargs {kwargs}")
wrapper_func = self.task(func)
return wrapper_func(*args, **kwargs).result()

def wait_tasks(self, tasks):
wait(tasks)




class PilotAPIException(Exception):
pass


class PilotCompute(object):
"""PilotCompute (PC) representation.
This class is returned by the PilotComputeService when a new PilotCompute
(aka Pilot-Job) is created based on a PilotComputeDescription.
The PilotCompute object can be used by the application to keep track
of active PilotComputes. It has state, can be queried, can be cancelled,
and be re-initialized.
"""

def __init__(self, batch_job=None, cluster_manager=None):
self.batch_job = batch_job
self.cluster_manager = cluster_manager

def cancel(self):
if self.cluster_manager:
self.cluster_manager.cancel()

def get_state(self):
"""
Expand All @@ -207,18 +258,8 @@ def get_id(self):
def get_details(self):
return self.cluster_manager.get_config_data()

def get_client(self):
"""
Returns the native client for interacting with the task execution engine (i.e. Dask or Ray) started via the Pilot-Job.
see also get_context()
"""
return self.cluster_manager.get_client()

def wait(self):
self.cluster_manager.wait()

def wait_tasks(self, tasks):
wait(tasks)


def get_context(self, configuration=None):
Expand Down
Loading

0 comments on commit 211ef98

Please sign in to comment.