diff --git a/autodist/__init__.py b/autodist/__init__.py
index d64fd19..39a91a3 100644
--- a/autodist/__init__.py
+++ b/autodist/__init__.py
@@ -26,7 +26,8 @@ logging.set_verbosity(ENV.AUTODIST_MIN_LOG_LEVEL.val)
 
 # Enforce abspath
-if sys.argv and os.path.exists(sys.argv[0]) and not os.path.isabs(sys.argv[0]):
+if sys.argv and os.path.exists(sys.argv[0]) and not os.path.isabs(sys.argv[0]) \
+        and "ray" not in sys.modules:
     logging.error('AutoDist requires the script path "{}" to be an absolute path to be shared across workers. '
                   'Now exit.'.format(sys.argv[0]))
     sys.exit(1)
diff --git a/autodist/autodist.py b/autodist/autodist.py
index 69a8598..170cd30 100644
--- a/autodist/autodist.py
+++ b/autodist/autodist.py
@@ -37,12 +37,20 @@ from autodist.strategy.ps_lb_strategy import PSLoadBalancing
 from autodist.utils import logging
 
 
-IS_AUTODIST_WORKER = bool(ENV.AUTODIST_WORKER.val)
-IS_AUTODIST_CHIEF = not IS_AUTODIST_WORKER
 _DEFAULT_AUTODIST = {}
 
 
+def IS_AUTODIST_WORKER():  # noqa
+    """True if the current process is a non-chief worker."""
+    return bool(ENV.AUTODIST_WORKER.val)
+
+
+def IS_AUTODIST_CHIEF():  # noqa
+    """True if the current process is the chief."""
+    return not IS_AUTODIST_WORKER()
+
+
 def set_default_autodist(o):
     """Set the AutoDist object the scope of which you are in."""
     global _DEFAULT_AUTODIST
@@ -64,9 +72,12 @@ class _AutoDistInterface:
     Ancestor of _V1Graph, _V2Graph, and _V2Eager -- the different ways to run TF code.
     """
 
-    def __init__(self, resource_spec_file, strategy_builder=None):
+    def __init__(self, resource_spec_file=None, strategy_builder=None, resource_spec=None, strategy=None):
         set_default_autodist(self)
-        self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
+        if resource_spec_file is not None:
+            self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
+        else:
+            self._resource_spec = resource_spec
         self._strategy_builder = strategy_builder or PSLoadBalancing()
 
         self._original_graph_item = None
@@ -76,6 +87,7 @@ def __init__(self, resource_spec_file, strategy_builder=None):
         self._cluster: Cluster = SSHCluster(self._resource_spec)  # which can be also defined with strategy
         self._coordinator: Coordinator
+        self._strategy = strategy  # Strategy passed in directly by Ray workers, if not None
 
     @tf_contextlib.contextmanager
     def _scope(self):
@@ -99,9 +111,11 @@ def build_strategy(self):
 
     def _build_or_load_strategy(self):
         self._original_graph_item.prepare()
-        if IS_AUTODIST_CHIEF:
+        if IS_AUTODIST_CHIEF():
             s = self.build_strategy()
             s.serialize()
+        elif self._strategy is not None:
+            s = self._strategy
         else:
             strategy_id = ENV.AUTODIST_STRATEGY_ID.val
             assert strategy_id
@@ -119,7 +133,7 @@ def _compile_strategy(self, strategy):
 
     def _setup(self, strategy):
         """Prepare for the execution."""
-        if IS_AUTODIST_CHIEF:
+        if IS_AUTODIST_CHIEF() and not ENV.AUTODIST_RAY_BACKEND.val:
             # we should only have one single coordinator for one single AutoDist() instance scope,
             # even though we could have multiple strategies.
             self._coordinator = Coordinator(strategy=strategy, cluster=self._cluster)
@@ -148,6 +162,7 @@ def _build(self):
         self._transformed_graph_item = graph_transformer.transform()
         self._remapper = Remapper(graph_transformer, self._transformed_graph_item)
         self._built = self._original_graph_item.graph.as_graph_def()
+        self._strategy = strategy
 
     def is_built(self):
         """
diff --git a/autodist/const.py b/autodist/const.py
index 235f803..318bdfa 100644
--- a/autodist/const.py
+++ b/autodist/const.py
@@ -80,6 +80,7 @@ class ENV(Enum):
     AUTODIST_INTERNAL_TF = auto(), lambda v: (v or "False") == "True"  # noqa: E731
     SYS_DATA_PATH = auto(), lambda v: v or ""  # noqa: E731
     SYS_RESOURCE_PATH = auto(), lambda v: v or ""  # noqa: E731
+    AUTODIST_RAY_BACKEND = auto(), lambda v: True if v == "True" else False  # noqa: E731
 
     @property
     def val(self):
diff --git a/autodist/ray/__init__.py b/autodist/ray/__init__.py
new file mode 100644
index 0000000..ef6503d
--- /dev/null
+++ b/autodist/ray/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2021 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .backend import TFTrainer
+from .backend import TFRunner
diff --git a/autodist/ray/backend.py b/autodist/ray/backend.py
new file mode 100644
index 0000000..8734039
--- /dev/null
+++ b/autodist/ray/backend.py
@@ -0,0 +1,241 @@
+# Copyright 2021 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AutoDist Ray backend, including the TFRunner and TFTrainer implementations."""
+import os
+import tensorflow as tf
+import tensorflow.compat.v1 as v1
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.training.server_lib import ClusterSpec, Server
+import ray
+
+from autodist import AutoDist
+from autodist.const import ENV, DEFAULT_GROUP_LEADER
+from autodist.resource_spec import ResourceSpec
+from autodist.resource_spec import DeviceSpec
+from autodist.cluster import Cluster
+from autodist.checkpoint.saver import Saver as autodist_saver
+
+
+@ray.remote
+class TFServer:
+    """TensorFlow server actor responsible for executing the actual ops."""
+
+    @staticmethod
+    def launch(cluster_spec, job_name, task_index, num_cpu_device):
+        """Start the TF server.
This call blocks.""" + experimental = config_pb2.ConfigProto.Experimental( + collective_nccl=True, + collective_group_leader=DEFAULT_GROUP_LEADER) + s = Server( + ClusterSpec(cluster_spec), + job_name=job_name, + task_index=task_index, + config=config_pb2.ConfigProto( + experimental=experimental, + device_count={"CPU": num_cpu_device}, + inter_op_parallelism_threads=0, + intra_op_parallelism_threads=0, + ) + ) + s.join() + + +class TFRunner: + """Each TFRunner including master represents one replica of the training job.""" + + def __init__(self, # pylint: disable=too-many-arguments + strategy_builder, + strategy, + train_step, + model_fn, + input_fn, + env, + resource_spec): + self._epoch = 0 + # Setup environment vars for the new runner + for var, val in env.items(): + if isinstance(val, bool): + os.environ[var] = "True" if val else "False" + else: + os.environ[var] = val + + # Set Ray backend to True + os.environ[ENV.AUTODIST_RAY_BACKEND.name] = "True" + + # We either pass a strategy_builder or directly a strategy + self._autodist = AutoDist(strategy_builder=strategy_builder, + strategy=strategy, + resource_spec=resource_spec) + self._g = v1.Graph() + with self._g.as_default(), self._autodist.scope(): + # model_fn and input_fn can return multiple things, pack and + # unpack them into the step function + models = model_fn() + inputs = input_fn() + if isinstance(inputs, tuple): + iterators = (i.get_next() if hasattr(i, 'get_next') + else i for i in inputs) + else: + iterators = (inputs.get_next() if hasattr(inputs, 'get_next') + else inputs,) + + if not isinstance(models, tuple): + models = (models,) + + # Create saver before creating the session + self._saver = autodist_saver() + self._fetches = train_step(*models, *iterators) + self._session = self._autodist.create_distributed_session() + + def step(self): + """Take one training step.""" + self._epoch += 1 + return self._session.run(self._fetches) + + def get_strategy(self): + """Fetch the current strategy.""" + return self._autodist._strategy + + def save(self, checkpoint_dir, checkpoint_prefix=""): + """Save a TF checkpoint.""" + self._saver.save(self._session, checkpoint_dir + checkpoint_prefix, global_step=self._epoch + 1) + self._saver.restore(self._session, tf.train.latest_checkpoint(checkpoint_dir)) + + def restore(self, checkpoint_dir): + """Restore the checkpoint from the directory.""" + with self._g.as_default(), self._autodist.scope(): + self._saver.restore(self._session, tf.train.latest_checkpoint(checkpoint_dir)) + + +class TFTrainer: + """TFTrainer represents one training job.""" + + def __init__(self, strategy_builder, train_step, model_fn, input_fn): + + # Set Ray backend + os.environ[ENV.AUTODIST_RAY_BACKEND.name] = "True" + + # Go from resource_info -> ResourceSpec -> ClusterSpec + self._resource_spec = ResourceSpec( + resource_info=self._get_resource_info()) + + self._replicas = [] # Replica actors, also contains master + + # Start TF Servers on each node of the cluster + self._start_tf_servers() + + def spawn_replica(replica_host, strategy_builder, strategy=None, env=None): + # Enforce actor placement on the provided host + runner = ray.remote(resources={f"node:{replica_host}": 0.01}, + num_cpus=1)(TFRunner) + return runner.remote(strategy_builder, + strategy, + train_step, + model_fn, + input_fn, + env if env is not None else {}, + self._resource_spec) + + # Start the master worker, let it build a strategy from the strategy builder + self._master = spawn_replica(ray._private.services.get_node_ip_address(), 
+                                     strategy_builder)
+
+        # Add the master to the replicas list because it also acts as one of the clients
+        self._replicas.append((ray._private.services.get_node_ip_address(), self._master))
+
+        # Fetch the strategy directly from the master
+        strategy = ray.get(self._master.get_strategy.remote())
+
+        assert strategy is not None
+
+        # Spawn the remaining clients based on the strategy built by the master
+        replica_devices = [
+            DeviceSpec.from_string(device_string)
+            for device_string in strategy.graph_config.replicas
+        ]
+
+        replica_hosts = {d.host_address for d in replica_devices}
+        for replica_host in replica_hosts:
+            if replica_host != ray._private.services.get_node_ip_address():
+                # Only non-master replicas
+                env = {
+                    ENV.AUTODIST_WORKER.name: replica_host,
+                    ENV.AUTODIST_MIN_LOG_LEVEL.name: ENV.AUTODIST_MIN_LOG_LEVEL.val,
+                    ENV.AUTODIST_IS_TESTING.name: ENV.AUTODIST_IS_TESTING.val,
+                    ENV.AUTODIST_PATCH_TF.name: ENV.AUTODIST_PATCH_TF.val,
+                    ENV.AUTODIST_INTERNAL_TF.name: ENV.AUTODIST_INTERNAL_TF.val,
+                    ENV.SYS_DATA_PATH.name: ENV.SYS_DATA_PATH.val,
+                    ENV.SYS_RESOURCE_PATH.name: ENV.SYS_RESOURCE_PATH.val,
+                }
+                self._replicas.append((replica_host, spawn_replica(replica_host, None, strategy, env)))
+
+    def _start_tf_servers(self):
+        """Launch a TF server actor on each Ray node."""
+        cluster_spec = Cluster._get_default_cluster_spec(self._resource_spec)
+        cpu_devices = Cluster._get_node_cpu_devices(self._resource_spec)
+        gpu_devices = Cluster._get_node_gpu_devices(self._resource_spec)
+
+        self._servers = []
+        for job_name, tasks in cluster_spec.items():
+            for task_index, full_address in enumerate(tasks):
+                node_ip, _ = full_address.split(':')
+                # Make sure we spawn one server per Ray node
+                # and give it all the GPUs on that node
+                server = TFServer.options(resources={f"node:{node_ip}": 0.01},
+                                          num_gpus=len(gpu_devices.get(node_ip, []))).remote()
+                self._servers.append(server)
+                server.launch.remote(cluster_spec,
+                                     job_name,
+                                     task_index,
+                                     len(cpu_devices[node_ip]))
+
+    @staticmethod
+    def _get_resource_info():
+        """Create resource_info from the resources available to the Ray cluster."""
+        resource_info = {}
+        resource_info["nodes"] = []
+        chief_address = ray._private.services.get_node_ip_address()
+        for node in ray.nodes():
+            node_ip = node["NodeManagerAddress"]
+            cpu_count = node["Resources"].get("CPU")
+            gpu_count = node["Resources"].get("GPU")
+            if not node["Alive"] or (cpu_count is None and gpu_count is None):
+                continue
+            node = {"address": node_ip,
+                    "cpus": [0] if cpu_count else [],
+                    "gpus": list(range(int(gpu_count))) if gpu_count else []}
+            if node_ip == chief_address:
+                node["chief"] = True
+            resource_info["nodes"].append(node)
+        return resource_info
+
+    def train(self):
+        """Run one training step on every replica and return the per-host fetches."""
+        return dict(zip([replica[0] for replica in self._replicas],
+                        ray.get([replica[1].step.remote() for replica in self._replicas])))
+
+    def save(self, checkpoint_dir, checkpoint_prefix):
+        """Save a checkpoint with the given prefix."""
+        ray.get(self._master.save.remote(checkpoint_dir, checkpoint_prefix))
+
+    def restore(self, checkpoint_dir):
+        """Restore the latest checkpoint from the directory."""
+        ray.get(self._master.restore.remote(checkpoint_dir))
+
+    def shutdown(self):
+        """Shut down all the actors and the training job."""
+        for server in self._servers:
+            ray.kill(server)
+        for replica in self._replicas:
+            ray.kill(replica[1])
diff --git a/autodist/resource_spec.py b/autodist/resource_spec.py
index 5945aef..96cbc66 100644
--- a/autodist/resource_spec.py
+++ b/autodist/resource_spec.py
@@ -22,6 +22,7 @@
 from autodist.utils import logging
 from autodist.utils.network import is_loopback_address, is_local_address
+from autodist.const import ENV
 
 
 class Connectivity(Enum):
@@ -52,7 +53,7 @@ class ResourceSpec:
     This would allow for even more intelligent strategy generation.
     """
 
-    def __init__(self, resource_file=None):
+    def __init__(self, resource_file=None, resource_info=None):
         """
         Construct a device graph containing the connectivity between devices.
 
@@ -61,6 +62,7 @@ def __init__(self, resource_file=None):
         Args:
             resource_file (string, optional): path to the file containing the resource info. Defaults to None.
+            resource_info (optional): resource_info dictionary, used if resource_file is None.
         """
         # protected properties
         self.__devices = dict()
@@ -75,7 +77,10 @@ def __init__(self, resource_file=None):
         self.__network_bandwidth = dict()
 
         # set self.__devices
-        self._from_resource_info(resource_file)
+        if resource_info is not None:
+            self._from_resource_info(resource_info)
+        else:
+            self._from_resource_info_file(resource_file)
 
     @property
     def chief(self):
@@ -125,7 +130,7 @@ def node_gpu_devices(self):
 
     @property
     def node_cpu_devices(self):
-        """Node_address-to-device_string mapping of all cpu devices."""
+        """Node_address-to-device_string mapping of all CPU devices."""
         _cpu_devices = dict()
         for device in self.cpu_devices:
             _cpu_devices.setdefault(device[0].split(':')[0], []).append(device[0])
@@ -157,12 +162,7 @@ def _add_device(self, device_spec):
         if device_spec.name_string() not in self.__devices:
             self.__devices[device_spec.name_string()] = device_spec
 
-    def _from_resource_info(self, resource_file=None):
-        if resource_file is None:
-            # TODO(Hao): To deal with single-node GPUs
-            return
-
-        resource_info = yaml.safe_load(open(resource_file, 'r'))
+    def _from_resource_info(self, resource_info):
         num_nodes = len(resource_info.get('nodes', {}))
         for node in resource_info.pop('nodes', {}):
@@ -182,6 +182,14 @@ def _from_resource_info(self, resource_file=None):
         if self.__chief_address is None:
             raise ValueError('Must provide "chief: true" in one of the nodes in resource spec.')
 
+    def _from_resource_info_file(self, resource_file=None):
+        if resource_file is None:
+            # TODO(Hao): To deal with single-node GPUs
+            return
+
+        resource_info = yaml.safe_load(open(resource_file, 'r'))
+        self._from_resource_info(resource_info)
+
     def _parse_node(self, node, num_nodes):
         host_address = node['address']
         if is_loopback_address(host_address) and num_nodes > 1:
@@ -204,7 +212,8 @@ def _parse_node(self, node, num_nodes):
                 gpu = DeviceSpec(host_address, host_cpu, DeviceType.GPU, gpu_index)
                 self._add_device(gpu)
         self.__ssh_group[host_address] = node.get('ssh_config')
-        if self.__ssh_group[host_address] is None and self.__chief_address != host_address:
+        if self.__ssh_group[host_address] is None and self.__chief_address != host_address \
+                and not ENV.AUTODIST_RAY_BACKEND.val:
             raise ValueError("Need to define SSH groups for all non-chief nodes.")
         # handle network bandwidth (optional)
         if node.get('network_bandwidth'):
diff --git a/docs/usage/tutorials/save-restore.md b/docs/usage/tutorials/save-restore.md
index dcaa8b5..7bf3a17 100644
--- a/docs/usage/tutorials/save-restore.md
+++ b/docs/usage/tutorials/save-restore.md
@@ -114,7 +114,7 @@ from autodist.autodist import IS_AUTODIST_CHIEF
 
 # Some training code ...
-if IS_AUTODIST_CHIEF:
+if IS_AUTODIST_CHIEF():
     saver.save(session, checkpoint_name, global_step=epoch)
     print('Checkpoint saved at {%s}' % checkpoint_name)
 else:
diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md
index bd21337..314896d 100644
--- a/examples/benchmark/README.md
+++ b/examples/benchmark/README.md
@@ -18,6 +18,19 @@ The instruction for generating the training data and setting up the pre-trained
 ```
 python ${REAL_SCRIPT_PATH}/bert.py -input_files=${REAL_DATA_PATH}/sample_data_tfrecord/*.tfrecord --bert_config_file=${REAL_DATA_PATH}/uncased_L-24_H-1024_A-16/bert_config --num_train_epochs=1 --learning_rate=5e-5 --steps_per_loop=20 --autodist_strategy=$AUTODIST_STRATEGY
 ```
+##### Running BERT on the Ray backend
+AutoDist can be used with Ray via a RaySGD-style API. To start a Ray cluster, first start the head node with
+```
+ray start --head --port 6379 --include-dashboard=false
+```
+and then attach any other nodes to the head node. The training job can then be started by running
+```
+python bert_ray.py --input_files=data/*.tfrecord --bert_config_file=bert_config.json
+```
+where `data/` holds all the pretraining data and `bert_config.json` is the configuration file. This submits the job to the local Ray cluster (`address='auto'`); use the `--address` argument if you are targeting a different cluster. The `data/` directory has to be present on all nodes of the cluster at the same path. The example supports all other arguments from the base implementation, such as `--autodist_strategy`.
+
+A few caveats: on some platforms the TensorFlow servers might complain about too many open files during execution. You can avoid these errors by setting a higher open file handle limit (for example `ulimit -n 1064`) on all nodes before starting the Ray cluster.
+To use a custom CUDA path, export it before starting the Ray cluster processes.
 
 #### Neural Collaborative Filtering (NCF)
 The instruction for generating the training data can be found following [this link](https://github.com/tensorflow/models/tree/master/official/recommendation).
diff --git a/examples/benchmark/bert_ray.py b/examples/benchmark/bert_ray.py
new file mode 100644
index 0000000..1a5a947
--- /dev/null
+++ b/examples/benchmark/bert_ray.py
@@ -0,0 +1,220 @@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# It includes derived work based on:
+#
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import yaml
+import os
+import sys
+import ray
+import tensorflow as tf
+from absl import app
+from absl import flags
+from absl import logging
+
+from utils.logs import logger
+from utils.misc import keras_utils
+
+from utils import bert_modeling as modeling
+from utils import bert_models
+from utils import common_flags
+from utils import input_pipeline
+from utils import bert_utils
+from utils import ray_utils
+
+#########################################################################
+# Import AutoDist and Strategy
+from autodist import AutoDist
+from autodist.strategy.all_reduce_strategy import AllReduce
+from autodist.strategy.ps_strategy import PS
+from autodist.strategy.ps_lb_strategy import PSLoadBalancing
+from autodist.strategy.parallax_strategy import Parallax
+from autodist.strategy.partitioned_ps_strategy import PartitionedPS
+#########################################################################
+
+flags.DEFINE_string(
+    'input_files',
+    None,
+    'File path to retrieve training data for pre-training.')
+flags.DEFINE_integer(
+    'max_seq_length', 128,
+    'The maximum total input sequence length after WordPiece tokenization. '
+    'Sequences longer than this will be truncated, and sequences shorter '
+    'than this will be padded.')
+flags.DEFINE_integer('max_predictions_per_seq', 20,
+                     'Maximum predictions per sequence_output.')
+flags.DEFINE_integer('train_batch_size', 2, 'Total batch size for training.')
+flags.DEFINE_integer('chunk_size', 256, 'The chunk size for training.')
+flags.DEFINE_integer('num_steps_per_epoch', 1000,
+                     'Total number of training steps to run per epoch.')
+flags.DEFINE_string(
+    name='autodist_strategy',
+    default='PS',
+    help='the autodist strategy')
+flags.DEFINE_boolean(
+    name='autodist_patch_tf',
+    default=True,
+    help='AUTODIST_PATCH_TF')
+flags.DEFINE_string(
+    'address',
+    'auto',
+    'IP address of the Ray head node')
+
+flags.DEFINE_boolean(name='proxy', default=True, help='turn on/off the proxy')
+
+
+common_flags.define_common_bert_flags()
+
+FLAGS = flags.FLAGS
+
+
+def get_pretrain_dataset_fn(input_file_pattern, seq_length,
+                            max_predictions_per_seq, global_batch_size,
+                            num_replicas_in_sync):
+    """Returns input dataset from input file string."""
+    def _dataset_fn(ctx=None):
+        """Returns tf.data.Dataset for distributed BERT pretraining."""
+        input_patterns = input_file_pattern.split(',')
+        batch_size = int(global_batch_size / num_replicas_in_sync)
+        train_dataset = input_pipeline.create_pretrain_dataset(
+            input_patterns,
+            seq_length,
+            max_predictions_per_seq,
+            batch_size,
+            is_training=True)
+        return train_dataset
+
+    return _dataset_fn
+
+
+def get_loss_fn(loss_factor=1.0):
+    """Returns loss function for BERT pretraining."""
+
+    def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
+        return tf.keras.backend.mean(losses) * loss_factor
+
+    return _bert_pretrain_loss_fn
+
+
+def run_customized_training(strategy,
+                            bert_config,
+                            max_seq_length,
+                            max_predictions_per_seq,
+                            model_dir,
+                            steps_per_epoch,
+                            steps_per_loop,
+                            epochs,
+                            initial_lr,
+                            input_files,
+                            train_batch_size,
+                            num_replicas):
+    def _get_pretrain_model():
+        """Gets a pretraining model."""
+        pretrain_model, core_model = bert_models.pretrain_model(
+            bert_config, max_seq_length, max_predictions_per_seq)
+
+        pretrain_model.optimizer = tf.optimizers.Adam(lr=initial_lr)
+
+        return pretrain_model, core_model
+
+    time_callback = keras_utils.TimeHistory(
+        train_batch_size * steps_per_loop, 1)
+
+    train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
+                                             max_predictions_per_seq,
+                                             train_batch_size,
+                                             num_replicas)
+
+    ray_utils.run_ray_job(strategy=strategy,
+                          model_fn=_get_pretrain_model,
+                          loss_fn=get_loss_fn(loss_factor=1.0),
+                          model_dir=model_dir,
+                          train_input_fn=train_input_fn,
+                          steps_per_epoch=steps_per_epoch,
+                          steps_per_loop=steps_per_loop,
+                          epochs=epochs,
+                          sub_model_export_name='pretrained/bert_model',
+                          custom_callbacks=[time_callback])
+
+
+def run_bert_pretrain(strategy, num_gpus=1, num_nodes=1):
+    """Runs BERT pre-training."""
+
+    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+    logging.info(
+        'Training using customized training loop TF 2.0 with AutoDist')
+
+    run_customized_training(
+        strategy,
+        bert_config,
+        FLAGS.max_seq_length,
+        FLAGS.max_predictions_per_seq,
+        FLAGS.model_dir,
+        FLAGS.num_steps_per_epoch,
+        FLAGS.steps_per_loop,
+        FLAGS.num_train_epochs,
+        FLAGS.learning_rate,
+        FLAGS.input_files,
+        FLAGS.train_batch_size * num_nodes * num_gpus,
+        num_nodes * num_gpus)
+
+
+def main(_):
+    assert tf.version.VERSION.startswith('2.')
+
+    if not FLAGS.model_dir:
+        FLAGS.model_dir = "/tmp/ckpt/"
+
+    #########################################################################
+    # Construct AutoDist with ResourceSpec for Different Strategies
+    if FLAGS.autodist_patch_tf:
+        os.environ['AUTODIST_PATCH_TF'] = 'True'
+    else:
+        os.environ['AUTODIST_PATCH_TF'] = 'False'
+
+    strategy_table = {'PS': PS(local_proxy_variable=FLAGS.proxy),
+                      'PSLoadBalancing': PSLoadBalancing(local_proxy_variable=FLAGS.proxy),
+                      'PartitionedPS': PartitionedPS(local_proxy_variable=FLAGS.proxy),
+                      'AllReduce': AllReduce(chunk_size=FLAGS.chunk_size),
+                      'Parallax': Parallax(chunk_size=FLAGS.chunk_size,
+                                           local_proxy_variable=FLAGS.proxy)}
+
+    if FLAGS.autodist_strategy not in strategy_table:
+        raise ValueError(
+            f"the strategy can only be one of {','.join(strategy_table.keys())}")
+
+    logdir = '/tmp/logs'
+    if not os.path.exists(logdir):
+        os.makedirs(logdir)
+
+    ray.init(address=FLAGS.address)
+    num_nodes = len(ray.nodes())
+    num_gpus_per_node = max(1, ray.nodes()[0]['Resources'].get('GPU', 0))
+
+    logname = 'bert_strategy_{}_node_{}_gpu_{}_patch_{}_proxy_{}'.format(
+        FLAGS.autodist_strategy, num_nodes, num_gpus_per_node, FLAGS.autodist_patch_tf, FLAGS.proxy)
+
+    logging.get_absl_handler().use_absl_log_file(logname, logdir)
+
+    run_bert_pretrain(strategy_table[FLAGS.autodist_strategy], num_gpus_per_node, num_nodes)
+
+
+if __name__ == '__main__':
+    logging.set_verbosity(logging.INFO)
+    app.run(main)
diff --git a/examples/benchmark/utils/ray_utils.py b/examples/benchmark/utils/ray_utils.py
new file mode 100644
index 0000000..acbb5e9
--- /dev/null
+++ b/examples/benchmark/utils/ray_utils.py
@@ -0,0 +1,102 @@
+# Copyright 2021 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+import os
+import time
+import ray
+import numpy as np
+import tensorflow as tf
+
+from autodist.ray import TFTrainer, TFRunner
+
+def run_ray_job(strategy,
+                model_fn,
+                loss_fn,
+                model_dir,
+                train_input_fn,
+                steps_per_epoch,
+                steps_per_loop,
+                epochs,
+                sub_model_export_name,
+                custom_callbacks):
+
+    def _get_input_iterator(input_fn, strategy):
+        """Returns a distributed dataset iterator."""
+        # When training with TPU pods, datasets need to be cloned across
+        # workers. Since a Dataset instance cannot be cloned in eager mode,
+        # we instead pass a callable that returns a dataset.
+        if not callable(input_fn):
+            raise ValueError(
+                '`input_fn` should be a closure that returns a dataset.')
+        if not isinstance(strategy, tf.distribute.Strategy):
+            iterator = tf.compat.v1.data.make_one_shot_iterator(input_fn())
+        else:
+            iterator = iter(
+                strategy.experimental_distribute_datasets_from_function(input_fn))
+        return iterator
+
+    def _replicated_step(model, core_model, inputs):
+        """Replicated training step."""
+        optimizer = model.optimizer
+        use_float16 = isinstance(
+            optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
+
+        inputs, labels = inputs
+        model_outputs = model(inputs, training=True)
+        loss = loss_fn(labels, model_outputs)
+        if use_float16:
+            scaled_loss = optimizer.get_scaled_loss(loss)
+
+        training_vars = model.trainable_variables
+        if use_float16:
+            scaled_grads = tf.gradients(scaled_loss, training_vars)
+            grads = optimizer.get_unscaled_gradients(scaled_grads)
+        else:
+            grads = tf.gradients(loss, training_vars)
+        train_op = optimizer.apply_gradients(zip(grads, training_vars))
+        return train_op, loss
+
+    def input_fn():
+        return _get_input_iterator(train_input_fn, strategy)
+
+    def _run_callbacks_on_batch_begin(batch):
+        """Runs custom callbacks at the start of every step."""
+        if not custom_callbacks:
+            return
+        for callback in custom_callbacks:
+            callback.on_batch_begin(batch)
+
+    def _run_callbacks_on_batch_end(batch):
+        """Runs custom callbacks at the end of every step."""
+        if not custom_callbacks:
+            return
+        for callback in custom_callbacks:
+            callback.on_batch_end(batch)
+
+    trainer = TFTrainer(strategy, _replicated_step, model_fn, input_fn)
+
+    for epoch in range(epochs):
+        _run_callbacks_on_batch_begin(epoch)
+        per_replica = trainer.train()
+        _run_callbacks_on_batch_end(epoch)
+        avg_loss = sum(val[1] for val in per_replica.values()) / len(per_replica)
+        print(f"Avg loss: {avg_loss}")
+
+    trainer.save("/tmp/ckpt/", checkpoint_prefix="bert")
+
+    trainer.shutdown()
diff --git a/examples/linear_regression_ray.py b/examples/linear_regression_ray.py
new file mode 100644
index 0000000..bc6773e
--- /dev/null
+++ b/examples/linear_regression_ray.py
@@ -0,0 +1,92 @@
+# Copyright 2021 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+import os
+import time
+import ray
+import numpy as np
+import tensorflow as tf
+
+from autodist.strategy import PS, PSLoadBalancing, PartitionedPS, AllReduce, Parallax
+from autodist.ray import TFTrainer, TFRunner
+
+ray.init(address='auto')
+
+EPOCHS = 10
+
+def data_creator():
+    TRUE_W = 3.0
+    TRUE_b = 2.0
+    NUM_EXAMPLES = 1000
+
+    inputs = np.random.randn(NUM_EXAMPLES)
+    noises = np.random.randn(NUM_EXAMPLES)
+    outputs = inputs * TRUE_W + TRUE_b + noises
+
+    class MyIterator:
+        def __init__(self, data):
+            self.data = data
+        def initialize(self):
+            return tf.zeros(1)
+        def get_next(self):
+            # a fake one
+            return self.data
+
+    return MyIterator(inputs), outputs
+
+
+class Model:
+    def __init__(self):
+        self.W = tf.Variable(5.0, name='W', dtype=tf.float64)
+        self.b = tf.Variable(0.0, name='b', dtype=tf.float64)
+
+    def __call__(self, x):
+        return self.W * x + self.b
+
+
+def train_step(model, inputs, outputs):
+    def l(predicted_y, desired_y):
+        return tf.reduce_mean(tf.square(predicted_y - desired_y))
+
+    major_version, _, _ = tf.version.VERSION.split('.')
+    if major_version == '1':
+        optimizer = tf.train.GradientDescentOptimizer(0.01)
+    else:
+        optimizer = tf.optimizers.SGD(0.01)
+
+    loss = l(model(inputs), outputs)
+    vs = [model.W, model.b]
+
+    gradients = tf.gradients(loss, vs)
+
+    train_op = optimizer.apply_gradients(zip(gradients, vs))
+    return loss, train_op, model.b
+
+def model_creator():
+    return Model()
+
+def main(_):
+    trainer = TFTrainer(PS(), train_step, model_creator, data_creator)
+    for epoch in range(EPOCHS):
+        per_replica = trainer.train()
+        for host, output in per_replica.items():
+            l, _, b = output
+            print(f"node:{host}\tloss: {l}\tb:{b}")
+
+    trainer.shutdown()
+
+main(sys.argv)
diff --git a/tests/integration/cases/c10.py b/tests/integration/cases/c10.py
index 23ac55d..278c5ce 100644
--- a/tests/integration/cases/c10.py
+++ b/tests/integration/cases/c10.py
@@ -77,7 +77,7 @@ def l(predicted_y, desired_y):
     # Only save the model on master node if autodist is used with NFS.
     checkpoint_suffix = 'c10'
     checkpoint_name = checkpoint_dir + checkpoint_suffix
-    if IS_AUTODIST_CHIEF:
+    if IS_AUTODIST_CHIEF():
         saver.save(session, checkpoint_name, global_step=epoch)
         print('Checkpoint saved at {%s}' % checkpoint_name)
     else:
@@ -85,7 +85,7 @@ def l(predicted_y, desired_y):
 
     # check the checkpoint existence only on master node
     checkpoint = checkpoint_name + '-' + str(epoch)
-    if IS_AUTODIST_CHIEF:
+    if IS_AUTODIST_CHIEF():
         assert(os.path.exists(checkpoint + '.meta'))  # meta file
         assert(os.path.exists(checkpoint + '.index'))  # meta file
         assert(os.path.exists(checkpoint + '.data-00000-of-00001'))  # meta file
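
For reference, the `resource_info` dictionary that `TFTrainer._get_resource_info()` derives from `ray.nodes()` has the same shape as the parsed YAML resource spec that `ResourceSpec._from_resource_info()` already consumes. A minimal sketch, with purely illustrative host addresses and device counts:

```python
# Sketch of the resource_info structure fed into ResourceSpec(resource_info=...).
# The addresses and device counts below are made up for illustration only.
resource_info = {
    "nodes": [
        {
            "address": "192.168.1.10",  # Ray head node, marked as chief
            "cpus": [0],
            "gpus": [0, 1],
            "chief": True,
        },
        {
            "address": "192.168.1.11",  # additional Ray worker node
            "cpus": [0],
            "gpus": [0, 1],
        },
    ]
}

# Equivalent to what the Ray backend does internally:
# from autodist.resource_spec import ResourceSpec
# spec = ResourceSpec(resource_info=resource_info)
```

Because `AUTODIST_RAY_BACKEND` is set by the Ray backend, `_parse_node` skips the `ssh_config` requirement for non-chief nodes, so no SSH configuration is needed in this dictionary.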