diff --git a/autodist/__init__.py b/autodist/__init__.py
index d64fd19..1c43853 100644
--- a/autodist/__init__.py
+++ b/autodist/__init__.py
@@ -40,7 +40,7 @@
         float_major_minor_tf_version
     ))
     sys.exit(1)
-logging.debug('AutoDist is now running on TensorFlow {}'.format(version.VERSION))
+logging.info('AutoDist is now running on TensorFlow {}'.format(version.VERSION))
 
 # Disable tensorflow control flow version 2 (which AutoDist does not support as of now).
 # Use control flow version 1 instead.
diff --git a/autodist/autodist.py b/autodist/autodist.py
index 69a8598..cf1396e 100644
--- a/autodist/autodist.py
+++ b/autodist/autodist.py
@@ -23,9 +23,9 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.util import tf_contextlib
 
-from autodist.cluster import Cluster, SSHCluster
-from autodist.const import ENV
-from autodist.coordinator import Coordinator
+from autodist.cluster import Cluster, SSHCluster, RayCluster
+from autodist.const import ENV, DEFAULT_WORKING_DIR
+from autodist.coordinator import Coordinator, RayCoordinator
 from autodist.graph_item import GraphItem
 from autodist.kernel.device.resolver import DeviceResolver
 from autodist.kernel.graph_transformer import GraphTransformer
@@ -64,17 +64,27 @@ class _AutoDistInterface:
     Ancestor of _V1Graph, _V2Graph, and _V2Eager -- the different ways to run TF code.
     """
 
-    def __init__(self, resource_spec_file, strategy_builder=None):
+    def __init__(self, resource_spec_file=None, strategy_builder=None):
         set_default_autodist(self)
-        self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
         self._strategy_builder = strategy_builder or PSLoadBalancing()
 
         self._original_graph_item = None
         self._transformed_graph_item = None
         self._remapper = None
         self._built = None  # Ref to the built GraphDef
-
-        self._cluster: Cluster = SSHCluster(self._resource_spec)  # which can be also defined with strategy
+        if resource_spec_file:
+            self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
+            self._cluster: Cluster = SSHCluster(self._resource_spec)  # which can also be defined with the strategy
+        else:
+            # No resource spec file: run on a Ray cluster and derive the resource spec from it.
+            if IS_AUTODIST_CHIEF:
+                self._cluster: Cluster = RayCluster()
+                self._resource_spec = self._cluster.get_resource_spec()
+            else:
+                # Build the resource spec from the file the chief wrote to DEFAULT_WORKING_DIR.
+                self._resource_spec = ResourceSpec(
+                    resource_file=os.path.join(DEFAULT_WORKING_DIR, "resource_spec.yml"))
+                self._cluster: Cluster = RayCluster(self._resource_spec)
 
         self._coordinator: Coordinator
 
     @tf_contextlib.contextmanager
@@ -122,7 +132,11 @@ def _setup(self, strategy):
         if IS_AUTODIST_CHIEF:
             # we should only have one single coordinator for one single AutoDist() instance scope,
             # even though we could have multiple strategies.
-            self._coordinator = Coordinator(strategy=strategy, cluster=self._cluster)
+            if isinstance(self._cluster, RayCluster):
+                # Temporarily dispatch to the Ray-based coordinator for Ray clusters.
+                self._coordinator = RayCoordinator(strategy=strategy, cluster=self._cluster)
+            else:
+                self._coordinator = Coordinator(strategy=strategy, cluster=self._cluster)
         self._cluster.start()
         self._coordinator.launch_clients()
         logging.info('Current PID {} belongs to address {}'.format(os.getpid(), self._cluster.get_local_address()))
diff --git a/autodist/cluster.py b/autodist/cluster.py
index 77f65d0..d4f160d 100644
--- a/autodist/cluster.py
+++ b/autodist/cluster.py
@@ -52,6 +52,7 @@ class Cluster(metaclass=ABCMeta):
     """Cluster manager for TensorFlow servers."""
 
     def __init__(self, resource_spec: ResourceSpec):
+        self._resource_spec = resource_spec
         self.cluster_spec = self._get_default_cluster_spec(resource_spec)
         self._cpu_devices = self._get_node_cpu_devices(resource_spec)
         self._gpu_devices = self._get_node_gpu_devices(resource_spec)
@@ -95,6 +96,9 @@ def _get_node_gpu_devices(resource_spec: ResourceSpec):
             _gpu_devices.setdefault(device[0].split(':')[0], []).append(':'.join(device[0].split(':')[1:]))
         return _gpu_devices
 
+    def get_resource_spec(self):
+        """Return the ResourceSpec this cluster was built from."""
+        return self._resource_spec
+
     def is_chief(self, address=None):
         """
         Check whether an address is chief or not.
@@ -196,14 +200,14 @@ def start(self):
                     self.subprocesses.append(proc)
                     # The above line immediately follows the Popen
                     # to ensure no gap for termination failure due to the empty proc list.
-                    logging.debug('$ local tf.server started at {}: job_name={} task_index={}'.format(
+                    logging.info('$ local tf.server started at {}: job_name={} task_index={}'.format(
                         full_address, job_name, task_index
                     ))
                 else:  # remote
-                    self.remote_pre_start_tf_server(address, tf_server_starter_filepath=module_file)
+                    self._remote_pre_start_tf_server(address, tf_server_starter_filepath=module_file)
                     file = os.path.join(DEFAULT_WORKING_DIR, os.path.basename(module_file))
                     bash = envs + envs_cuda + ['python', '-u', file] + args
-                    logging.info("Launching tf.server on %s" % address)
+                    logging.info(f"Launching tf.server on {address} with {bash}")
                     proc = self.remote_exec(bash, hostname=address)
                     # The above line immediately follows the Popen
                     # to ensure no gap for termination failure due to the empty proc list.
@@ -215,7 +219,7 @@ def terminate(self):
         for p in self.subprocesses:
             os.killpg(os.getpgid(p.pid), signal.SIGTERM)
 
-    def remote_pre_start_tf_server(self, hostname, tf_server_starter_filepath, working_dir=DEFAULT_WORKING_DIR):
+    def _remote_pre_start_tf_server(self, hostname, tf_server_starter_filepath, working_dir=DEFAULT_WORKING_DIR):
         """
         Prepare to start a TensorFlow server remotely.
@@ -372,3 +376,223 @@ def remote_copy(self, local_path, remote_path, hostname):
         with self._get_sftp_client(hostname) as sftp:
             sftp.put(localpath=local_path, remotepath=os.path.join(remote_path, os.path.basename(local_path)))
+
+
+# Ray-related cluster support
+import asyncio
+import ray
+import yaml
+
+
+@ray.remote
+class NodeActor:
+    """Per-node Ray actor that launches and supervises AutoDist subprocesses on its node."""
+
+    def __init__(self):
+        self._ready_event = asyncio.Event()
+        self._proc_dict = {}
+
+    # send()/wait() form a simple signal: wait() blocks on the actor until send() is called.
+    def send(self, clear=False):
+        self._ready_event.set()
+        if clear:
+            self._ready_event.clear()
+
+    async def wait(self, should_wait=True):
+        if should_wait:
+            await self._ready_event.wait()
+
+    def get_node_ip(self):
+        return ray._private.services.get_node_ip_address()
+
+    def execute_cmd(self, args):
+        full_cmd = ' '.join(args)
+        logging.info(f"Executing command on this node: {full_cmd}")
+        # pylint: disable=subprocess-popen-preexec-fn
+        proc = subprocess.Popen(full_cmd, shell=True, preexec_fn=os.setsid)
+        pid = proc.pid
+        self._proc_dict[pid] = proc
+        return pid
+
+    def launch_tf_server(self, args):
+        full_cmd = ' '.join(args)
+        logging.info(f"Launching tf.server: {full_cmd}")
+        # pylint: disable=subprocess-popen-preexec-fn
+        proc = subprocess.Popen(full_cmd, shell=True, preexec_fn=os.setsid,
+                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        # Block until the server reports it is up, so the caller only returns
+        # once the tf.server on this node is ready to accept sessions.
+        while True:
+            line = proc.stderr.readline().decode()
+            if not line:  # the process exited before the server came up
+                break
+            logging.debug(line.rstrip())
+            if "Started server with target" in line:
+                break
+
+        pid = proc.pid
+        self._proc_dict[pid] = proc
+        return pid
+
+    def join(self, pid):
+        self._proc_dict[pid].wait()
+        del self._proc_dict[pid]
+
+    def kill(self):
+        logging.info('Terminating the Ray node...')
+        for pid in self._proc_dict:
+            os.killpg(os.getpgid(pid), signal.SIGTERM)
+        self._proc_dict.clear()
+
+    def file_write(self, remote_path, data):
+        with open(remote_path, 'w') as f:
+            f.write(data)
+
+
+class RayCluster(Cluster):
+    """An AutoDist cluster backed by Ray."""
+
+    def __init__(self, resource_spec=None):
+        if not resource_spec:
+            # No spec was given, so this process is the AutoDist chief:
+            # bring up the Ray actors and build the resource spec from the Ray nodes.
+            logging.info("Initializing the Ray cluster")
+            self._init_ray_actors()
+            resource_spec = ResourceSpec(
+                resource_file=os.path.join(DEFAULT_WORKING_DIR, "resource_spec.yml"))
+
+        super().__init__(resource_spec)
+
+    def _init_ray_actors(self):
+        ray_head_address = os.getenv("RAY_HEAD_ADDRESS")
+        if ray_head_address and not ray.is_initialized():
+            ray.init(address=ray_head_address)
+
+        self._actor_dict = {}
+        # For now, only nodes that expose GPU resources are considered.
+        gpu_list = []
+        for node in ray.nodes():
+            logging.debug(f"Ray node: {node}")
+            node_ip = node["NodeManagerAddress"]
+            gpu_count = node["Resources"].get("GPU")
+            if not gpu_count or not node["Alive"]:
+                continue
+            gpu_list.append((gpu_count, node_ip))
+
+        gpu_list = sorted(gpu_list, reverse=True)
+        logging.debug(f"gpu_list: {gpu_list}")
+        for gpu_count, _ in gpu_list:
+            actor = NodeActor.options(num_gpus=gpu_count).remote()
+            actor.wait.remote()
+            node_ip = ray.get(actor.get_node_ip.remote())
+            logging.debug(f"node ip: {node_ip}")
+            self._actor_dict[node_ip] = actor
+
+        chief_address = ray._private.services.get_node_ip_address()
+        logging.debug(f"chief address: {chief_address}")
+        resource_dict = {"nodes": []}
+        for gpu_count, node_ip in gpu_list:
+            node_dict = {"address": node_ip, "gpus": list(range(int(gpu_count)))}
+            if node_ip == chief_address:
+                node_dict["chief"] = True
+            resource_dict["nodes"].append(node_dict)
+
+        resource_yaml = yaml.dump(resource_dict)
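+        # Push the generated resource spec to DEFAULT_WORKING_DIR on every node,
+        # so that non-chief processes can build their ResourceSpec from the same file.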
print("resource yaml") + print(resource_yaml) + for node_ip, actor in self._actor_dict.items(): + ray.get(actor.file_write.remote( + os.path.join(DEFAULT_WORKING_DIR, "resource_spec.yml"), resource_yaml)) + + def remote_exec(self, args, hostname): + """ + Execute a bash script remotely. + + Args: + args (list): bash commands + hostname (str): means ip address in a ray cluster + + Returns: + Process: process handle + """ + return ray.get(self._actor_dict[hostname].execute_cmd.remote(args)) + + def remote_join(self, pid, hostname): + return ray.get(self._actor_dict[hostname].join.remote(pid)) + + def remote_file_write(self, remote_path, data, hostname): + """ + Write a remote file. + + Args: + remote_path (str): remote file path + data (str): data to be written + hostname (str): host name or address + """ + ray.get(self._actor_dict[hostname].file_write.remote(remote_path, data)) + + def remote_copy(self, local_path, remote_path, hostname): + """ + Copy a file to a remote directory. + + Args: + local_path (str): local file path to be copied + remote_path (str): remote directory path + hostname (str): host name or address + """ + # use remote_file_write internally + with open(local_path,"r") as f: + data = f.read() + + self.remote_file_write(os.path.join(remote_path, os.path.basename(local_path)), data, hostname) + + # pylint: disable=too-many-locals + def start(self): + """ + Start tf.servers on all nodes. + + Note that this only runs (and only should run) on the chief node. + """ + # pylint: disable=import-outside-toplevel + from autodist.utils import server_starter + + # atexit registration should be placed + # - before the beginning of the start + # (to ensure the clean termination if the start fails in its half way); and + # - at the same module as the start + # (to follow the python assumption that + # lower level modules will normally be imported + # before higher level modules and thus must be cleaned up later). 
+        atexit.register(self.terminate)
+        envs = {ENV.AUTODIST_MIN_LOG_LEVEL.name: 'ERROR'}
+        envs = ['{}={}'.format(k, v) for k, v in envs.items()]
+        module_name = server_starter.__name__
+        module_file = server_starter.__file__
+
+        for job_name, tasks in self.cluster_spec.items():
+            for task_index, full_address in enumerate(tasks):
+                address, port = full_address.split(':')
+                args = ['--job_name=%s' % job_name, '--task_index=%d' % task_index,
+                        '--cpu_device_num=%d' % len(self._cpu_devices[address])]
+                if address in self._gpu_devices:
+                    envs_cuda = []
+                else:
+                    envs_cuda = ['CUDA_VISIBLE_DEVICES=""']
+
+                self._remote_pre_start_tf_server(address, tf_server_starter_filepath=module_file)
+                file = os.path.join(DEFAULT_WORKING_DIR, os.path.basename(module_file))
+                bash = envs + envs_cuda + ['python', '-u', file] + args
+                logging.info(f"Launching tf.server on {address} with {bash}")
+                # launch_tf_server blocks until the server on that node reports it has started
+                ray.get(self._actor_dict[address].launch_tf_server.remote(bash))
+
+    def terminate(self):
+        """Terminate the cluster: stop the launched processes on every node and release the Ray actors."""
+        for actor in self._actor_dict.values():
+            ray.get(actor.kill.remote())
+            ray.get(actor.send.remote())
+
+        self._actor_dict.clear()
+        ray.shutdown()
diff --git a/autodist/coordinator.py b/autodist/coordinator.py
index 3ef4241..71a82de 100644
--- a/autodist/coordinator.py
+++ b/autodist/coordinator.py
@@ -59,6 +59,8 @@ def launch_clients(self):
             for device_string in self._strategy.graph_config.replicas
         ]
         replica_hosts = {d.host_address for d in replica_devices}
+        logging.debug("replica hosts: {}".format(replica_hosts))
+        logging.debug("sys.argv: {}".format(sys.argv))
 
         # Assumption: Master node must run one replica.
         # assert any([is_local_address(h) for h in replica_hosts])
@@ -86,6 +88,7 @@ def launch_clients(self):
                     remote_path=DEFAULT_SERIALIZATION_DIR,
                     hostname=replica_host
                 )
+                logging.info(f"Launching remote client with command: {cmd}")
                 proc = self.cluster.remote_exec(cmd, hostname=replica_host)
                 self.threads.append(self._proc_wait_async(proc))
 
@@ -108,3 +111,66 @@ def run_subprocess_in_thread(proc, on_exit):
         thread.start()
         # returns immediately after the thread starts
         return thread
+
+
+class RayCoordinator(Coordinator):
+    """Coordinator for Ray-based clusters; it must be paired with a RayCluster."""
+
+    def __init__(self, strategy, cluster):
+        super().__init__(strategy, cluster)
+        self._pids = {}
+
+    def launch_clients(self):
+        """
+        Launch the user's code on each worker.
+
+        Sets environment variables so that we run the correct AutoDist code paths on workers
+        (i.e., the non-chief code paths).
+
+        Store each new process created into the class so that it can be monitored with `join`.
+        """
+        atexit.register(self.join)
+
+        replica_devices = [
+            DeviceSpec.from_string(device_string)
+            for device_string in self._strategy.graph_config.replicas
+        ]
+        replica_hosts = {d.host_address for d in replica_devices}
+        logging.debug("replica hosts: {}".format(replica_hosts))
+        logging.debug("sys.argv: {}".format(sys.argv))
+
+        # Assumption: Master node must run one replica.
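+        # (The chief's replica runs in the current process; only non-chief hosts are launched through the Ray actors below.)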
+        # assert any([is_local_address(h) for h in replica_hosts])
+
+        for replica_host in replica_hosts:
+            # Run the process
+            if not self.cluster.is_chief(replica_host):
+                # Build the command
+                env = {
+                    ENV.AUTODIST_WORKER.name: replica_host,
+                    ENV.AUTODIST_STRATEGY_ID.name: self._strategy.id,
+                    ENV.AUTODIST_MIN_LOG_LEVEL.name: ENV.AUTODIST_MIN_LOG_LEVEL.val,
+                    ENV.AUTODIST_IS_TESTING.name: ENV.AUTODIST_IS_TESTING.val,
+                    ENV.AUTODIST_PATCH_TF.name: ENV.AUTODIST_PATCH_TF.val,
+                    ENV.AUTODIST_INTERNAL_TF.name: ENV.AUTODIST_INTERNAL_TF.val,
+                    ENV.SYS_DATA_PATH.name: ENV.SYS_DATA_PATH.val,
+                    ENV.SYS_RESOURCE_PATH.name: ENV.SYS_RESOURCE_PATH.val,
+                }
+                cmd_env = ['{}={}'.format(k, v) for k, v in env.items()]
+                cmd_main = ["python"] + sys.argv
+                cmd = cmd_env + cmd_main
+
+                self.cluster.remote_copy(
+                    local_path=self._strategy.path,
+                    remote_path=DEFAULT_SERIALIZATION_DIR,
+                    hostname=replica_host
+                )
+                logging.info(f"Launching remote client with command: {cmd}")
+                pid = self.cluster.remote_exec(cmd, hostname=replica_host)
+                self._pids[replica_host] = pid
+
+    def join(self):
+        """Wait for all subprocesses of remote workers to be completed."""
+        logging.info('Joining workers...')
+        for hostname, pid in self._pids.items():
+            self.cluster.remote_join(pid, hostname)
+        self._pids.clear()
+
diff --git a/autodist/resource_spec.py b/autodist/resource_spec.py
index 5945aef..2ab147f 100644
--- a/autodist/resource_spec.py
+++ b/autodist/resource_spec.py
@@ -204,8 +204,8 @@ def _parse_node(self, node, num_nodes):
                 gpu = DeviceSpec(host_address, host_cpu, DeviceType.GPU, gpu_index)
                 self._add_device(gpu)
         self.__ssh_group[host_address] = node.get('ssh_config')
-        if self.__ssh_group[host_address] is None and self.__chief_address != host_address:
-            raise ValueError("Need to define SSH groups for all non-chief nodes.")
+        # Relaxed for Ray-based clusters, where nodes are reached through Ray actors rather than SSH:
+        # if self.__ssh_group[host_address] is None and self.__chief_address != host_address:
+        #     raise ValueError("Need to define SSH groups for all non-chief nodes.")
         # handle network bandwidth (optional)
         if node.get('network_bandwidth'):
             self.__network_bandwidth[host_address] = node.get('network_bandwidth')
diff --git a/examples/linear_regression_ray.py b/examples/linear_regression_ray.py
new file mode 100644
index 0000000..dc6c20f
--- /dev/null
+++ b/examples/linear_regression_ray.py
@@ -0,0 +1,73 @@
+import sys
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from autodist import AutoDist
+from autodist.strategy import PS, PSLoadBalancing, PartitionedPS, AllReduce, Parallax
+
+
+def main(_):
+    autodist = AutoDist(strategy_builder=AllReduce(128))
+
+    TRUE_W = 3.0
+    TRUE_b = 2.0
+    NUM_EXAMPLES = 1000
+    EPOCHS = 10
+
+    inputs = np.random.randn(NUM_EXAMPLES)
+    noises = np.random.randn(NUM_EXAMPLES)
+    outputs = inputs * TRUE_W + TRUE_b + noises
+
+    class MyIterator:
+
+        def initialize(self):
+            return tf.zeros(1)
+
+        def get_next(self):
+            # a fake one
+            return inputs
+
+    inputs_iterator = MyIterator()
+    print('I am going to a scope.')
+    with tf.Graph().as_default() as g, autodist.scope():
+        # x = placeholder(shape=[NUM_EXAMPLES], dtype=tf.float32)
+
+        W = tf.Variable(5.0, name='W', dtype=tf.float64)
+        b = tf.Variable(0.0, name='b', dtype=tf.float64)
+
+        def train_step(input):
+
+            def y(x):
+                return W * x + b
+
+            def l(predicted_y, desired_y):
+                return tf.reduce_mean(tf.square(predicted_y - desired_y))
+
+            major_version, _, _ = tf.version.VERSION.split('.')
+            if major_version == '1':
+                optimizer = tf.train.GradientDescentOptimizer(0.01)
+            else:
+                optimizer = tf.optimizers.SGD(0.01)
+
+            with tf.GradientTape() as tape:
+                loss = l(y(input), outputs)
+                vs = [W, b]
+
+                # gradients = tape.gradient(target=loss, sources=vs)
+                gradients = tf.gradients(loss, vs)
+
+                train_op = optimizer.apply_gradients(zip(gradients, vs))
+            return loss, train_op, b
+
+        fetches = train_step(inputs_iterator.get_next())
+        session = autodist.create_distributed_session()
+        for epoch in range(EPOCHS):
+            l, t, b = session.run(fetches)
+            print('node: {}, loss: {}\nb:{}'.format(autodist._cluster.get_local_address(), l, b))
+
+        print('I am out of scope')
+
+
+main(sys.argv)
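A minimal, hypothetical launch sketch for the new Ray code path (not part of the patch). It assumes a Ray cluster is already running (for example, "ray start --head" on the head node and "ray start --address=<head-ip>:<port>" on each worker) and that the address below is a placeholder. The only handshake the patch itself defines is the RAY_HEAD_ADDRESS environment variable read by RayCluster._init_ray_actors(), plus the fact that AutoDist() is constructed without a resource_spec_file, which selects the RayCluster/RayCoordinator path added above.

    # Hypothetical chief-side launcher; run on the Ray head node.
    import os
    import subprocess

    env = dict(os.environ, RAY_HEAD_ADDRESS="10.0.0.1:6379")  # placeholder head address/port
    # The chief builds resource_spec.yml from ray.nodes(), writes it to DEFAULT_WORKING_DIR on
    # every node, launches one tf.server per node via NodeActor, and RayCoordinator then
    # re-launches linear_regression_ray.py (the chief's sys.argv) on each non-chief node
    # with AUTODIST_WORKER set.
    subprocess.run(["python", "examples/linear_regression_ray.py"], env=env, check=True)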