Autodist on Ray using RaySGD API #61

Open · wants to merge 17 commits into master
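
A minimal end-to-end sketch of driving the TFTrainer introduced by this diff, assuming a running Ray cluster and AutoDist's usual graph-mode training pattern; the toy model_fn / input_fn / train_step below are illustrative user code, not part of this PR:

import numpy as np
import tensorflow as tf
import ray

from autodist.ray import TFTrainer
from autodist.strategy.ps_lb_strategy import PSLoadBalancing


def model_fn():
    # Whatever is returned here is passed positionally to train_step.
    w = tf.Variable(0.0, name="w")
    optimizer = tf.optimizers.SGD(0.1)
    return w, optimizer


def input_fn():
    # TFRunner calls get_next() on the returned iterator while building the graph.
    xs = np.random.rand(1024).astype(np.float32)
    ys = 3.0 * xs
    dataset = tf.data.Dataset.from_tensor_slices((xs, ys)).repeat().batch(32)
    return tf.compat.v1.data.make_one_shot_iterator(dataset)


def train_step(w, optimizer, batch):
    # Build the training ops once; the returned fetches are run on every step.
    x, y = batch
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(w * x - y))
    grads = tape.gradient(loss, [w])
    train_op = optimizer.apply_gradients(zip(grads, [w]))
    return loss, train_op


ray.init(address="auto")  # connect to an already-running Ray cluster
trainer = TFTrainer(PSLoadBalancing(), train_step, model_fn, input_fn)
for _ in range(5):
    print(trainer.train())  # {host_ip: (loss_value, None), ...}, one entry per replica
trainer.save("/tmp/autodist_ckpt/", "linreg")
trainer.shutdown()
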
3 changes: 2 additions & 1 deletion autodist/__init__.py
@@ -26,7 +26,8 @@
logging.set_verbosity(ENV.AUTODIST_MIN_LOG_LEVEL.val)

# Enforce abspath
if sys.argv and os.path.exists(sys.argv[0]) and not os.path.isabs(sys.argv[0]):
if sys.argv and os.path.exists(sys.argv[0]) and not os.path.isabs(sys.argv[0]) \
and "ray" not in sys.modules:
logging.error('AutoDist requires the script path "{}" to be an absolute path to be shared across workers. '
'Now exit.'.format(sys.argv[0]))
sys.exit(1)
27 changes: 21 additions & 6 deletions autodist/autodist.py
@@ -37,12 +37,20 @@
from autodist.strategy.ps_lb_strategy import PSLoadBalancing
from autodist.utils import logging

IS_AUTODIST_WORKER = bool(ENV.AUTODIST_WORKER.val)
IS_AUTODIST_CHIEF = not IS_AUTODIST_WORKER

_DEFAULT_AUTODIST = {}


def IS_AUTODIST_WORKER(): # noqa
"""True if current worker is just a worker."""
return bool(ENV.AUTODIST_WORKER.val)


def IS_AUTODIST_CHIEF(): # noqa
"""True if current worker is the Chief."""
return not IS_AUTODIST_WORKER()


def set_default_autodist(o):
"""Set the AutoDist object the scope of which you are in."""
global _DEFAULT_AUTODIST
@@ -64,9 +72,12 @@ class _AutoDistInterface:
Ancestor of _V1Graph, _V2Graph, and _V2Eager -- the different ways to run TF code.
"""

def __init__(self, resource_spec_file, strategy_builder=None):
def __init__(self, resource_spec_file=None, strategy_builder=None, resource_spec=None, strategy=None):
set_default_autodist(self)
self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
if resource_spec_file is not None:
self._resource_spec = ResourceSpec(resource_file=resource_spec_file)
else:
self._resource_spec = resource_spec
self._strategy_builder = strategy_builder or PSLoadBalancing()

self._original_graph_item = None
@@ -76,6 +87,7 @@ def __init__(self, resource_spec_file, strategy_builder=None):

self._cluster: Cluster = SSHCluster(self._resource_spec) # which can also be defined with strategy
self._coordinator: Coordinator
self._strategy = strategy # Directly passed strategy to Ray workers if not None

@tf_contextlib.contextmanager
def _scope(self):
@@ -99,9 +111,11 @@ def build_strategy(self):

def _build_or_load_strategy(self):
self._original_graph_item.prepare()
if IS_AUTODIST_CHIEF:
if IS_AUTODIST_CHIEF():
s = self.build_strategy()
s.serialize()
elif self._strategy is not None:
s = self._strategy
else:
strategy_id = ENV.AUTODIST_STRATEGY_ID.val
assert strategy_id
@@ -119,7 +133,7 @@ def _compile_strategy(self, strategy):

def _setup(self, strategy):
"""Prepare for the execution."""
if IS_AUTODIST_CHIEF:
if IS_AUTODIST_CHIEF() and not ENV.AUTODIST_RAY_BACKEND.val:
# we should only have one single coordinator for one single AutoDist() instance scope,
# even though we could have multiple strategies.
self._coordinator = Coordinator(strategy=strategy, cluster=self._cluster)
@@ -148,6 +162,7 @@ def _build(self):
self._transformed_graph_item = graph_transformer.transform()
self._remapper = Remapper(graph_transformer, self._transformed_graph_item)
self._built = self._original_graph_item.graph.as_graph_def()
self._strategy = strategy

def is_built(self):
"""
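
For context on the AutoDist constructor change in this file, a hedged sketch of how the new keyword arguments are meant to be used. It mirrors the call TFRunner makes in autodist/ray/backend.py below and assumes the resource_info layout produced by TFTrainer._get_resource_info():

from autodist import AutoDist
from autodist.resource_spec import ResourceSpec

# Build a ResourceSpec in memory instead of loading it from a YAML file path.
spec = ResourceSpec(resource_info={
    "nodes": [{"address": "127.0.0.1", "cpus": [0], "gpus": [], "chief": True}]})

# resource_spec and strategy are the new keyword arguments: the chief still
# builds a strategy via the (default) strategy builder, while Ray worker
# replicas receive an already-built strategy and skip building one.
autodist = AutoDist(resource_spec=spec, strategy=None)
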
1 change: 1 addition & 0 deletions autodist/const.py
@@ -80,6 +80,7 @@ class ENV(Enum):
AUTODIST_INTERNAL_TF = auto(), lambda v: (v or "False") == "True" # noqa: E731
SYS_DATA_PATH = auto(), lambda v: v or "" # noqa: E731
SYS_RESOURCE_PATH = auto(), lambda v: v or "" # noqa: E731
AUTODIST_RAY_BACKEND = auto(), lambda v: v == "True" # noqa: E731

@property
def val(self):
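
A small sketch of how the new flag is read, assuming (as for the existing entries) that the val property applies the parsing lambda to the environment variable of the same name:

import os
from autodist.const import ENV

os.environ[ENV.AUTODIST_RAY_BACKEND.name] = "True"   # TFTrainer/TFRunner set this
assert ENV.AUTODIST_RAY_BACKEND.val is True          # any other value parses to False
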
16 changes: 16 additions & 0 deletions autodist/ray/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2021 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .backend import TFTrainer
from .backend import TFRunner
241 changes: 241 additions & 0 deletions autodist/ray/backend.py
@@ -0,0 +1,241 @@
# Copyright 2021 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Autodist Ray Backend, includes TFRunner and TFTrainer implementations."""
import os
import tensorflow as tf
import tensorflow.compat.v1 as v1
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.training.server_lib import ClusterSpec, Server
import ray

from autodist import AutoDist
from autodist.const import ENV, DEFAULT_GROUP_LEADER
from autodist.resource_spec import ResourceSpec
from autodist.resource_spec import DeviceSpec
from autodist.cluster import Cluster
from autodist.checkpoint.saver import Saver as autodist_saver


@ray.remote
class TFServer:
"""Tensorflow Server Actor responsible for executing the actual ops."""

@staticmethod
def launch(cluster_spec, job_name, task_index, num_cpu_device):
"""Start the TF server. This call blocks."""
experimental = config_pb2.ConfigProto.Experimental(
collective_nccl=True,
collective_group_leader=DEFAULT_GROUP_LEADER)
s = Server(
ClusterSpec(cluster_spec),
job_name=job_name,
task_index=task_index,
config=config_pb2.ConfigProto(
experimental=experimental,
device_count={"CPU": num_cpu_device},
inter_op_parallelism_threads=0,
intra_op_parallelism_threads=0,
)
)
s.join()


class TFRunner:
"""Each TFRunner including master represents one replica of the training job."""

def __init__(self, # pylint: disable=too-many-arguments
strategy_builder,
strategy,
train_step,
model_fn,
input_fn,
env,
resource_spec):
self._epoch = 0
# Setup environment vars for the new runner
for var, val in env.items():
if isinstance(val, bool):
os.environ[var] = "True" if val else "False"
else:
os.environ[var] = val

# Set Ray backend to True
os.environ[ENV.AUTODIST_RAY_BACKEND.name] = "True"

# We either pass a strategy_builder or directly a strategy
self._autodist = AutoDist(strategy_builder=strategy_builder,
strategy=strategy,
resource_spec=resource_spec)
self._g = v1.Graph()
with self._g.as_default(), self._autodist.scope():
# model_fn and input_fn can return multiple things, pack and
# unpack them into the step function
models = model_fn()
inputs = input_fn()
if isinstance(inputs, tuple):
iterators = (i.get_next() if hasattr(i, 'get_next')
else i for i in inputs)
else:
iterators = (inputs.get_next() if hasattr(inputs, 'get_next')
else inputs,)

if not isinstance(models, tuple):
models = (models,)

# Create saver before creating the session
self._saver = autodist_saver()
self._fetches = train_step(*models, *iterators)
self._session = self._autodist.create_distributed_session()

def step(self):
"""Take one training step."""
self._epoch += 1
return self._session.run(self._fetches)

def get_strategy(self):
"""Fetch the current strategy."""
return self._autodist._strategy

def save(self, checkpoint_dir, checkpoint_prefix=""):
"""Save a TF checkpoint."""
self._saver.save(self._session, checkpoint_dir + checkpoint_prefix, global_step=self._epoch + 1)
self._saver.restore(self._session, tf.train.latest_checkpoint(checkpoint_dir))

def restore(self, checkpoint_dir):
"""Restore the checkpoint from the directory."""
with self._g.as_default(), self._autodist.scope():
self._saver.restore(self._session, tf.train.latest_checkpoint(checkpoint_dir))


class TFTrainer:
"""TFTrainer represents one training job."""

def __init__(self, strategy_builder, train_step, model_fn, input_fn):

# Set Ray backend
os.environ[ENV.AUTODIST_RAY_BACKEND.name] = "True"

# Go from resource_info -> ResourceSpec -> ClusterSpec
self._resource_spec = ResourceSpec(
resource_info=self._get_resource_info())

self._replicas = [] # Replica actors, also contains master

# Start TF Servers on each node of the cluster
self._start_tf_servers()

def spawn_replica(replica_host, strategy_builder, strategy=None, env=None):
# Enforce actor placement on the provided host
runner = ray.remote(resources={f"node:{replica_host}": 0.01},
Contributor (review comment on the ray.remote call above): I believe this requires custom resource specification when you do ray up to start the ray cluster?
num_cpus=1)(TFRunner)
return runner.remote(strategy_builder,
strategy,
train_step,
model_fn,
input_fn,
env if env is not None else {},
self._resource_spec)

# Start the master worker, let it build a strategy from the strategy builder
self._master = spawn_replica(ray._private.services.get_node_ip_address(), strategy_builder)

# Add master to replicas list because it also acts as one of the clients
self._replicas.append((ray._private.services.get_node_ip_address(), self._master))

# Fetch the strategy directly from the master
strategy = ray.get(self._master.get_strategy.remote())

assert strategy is not None

# Spawn clients based on the strategy built by master
replica_devices = [
DeviceSpec.from_string(device_string)
for device_string in strategy.graph_config.replicas
]

replica_hosts = {d.host_address for d in replica_devices}
for replica_host in replica_hosts:
if replica_host != ray._private.services.get_node_ip_address():
# Only non-master replicas
env = {
ENV.AUTODIST_WORKER.name: replica_host,
ENV.AUTODIST_MIN_LOG_LEVEL.name: ENV.AUTODIST_MIN_LOG_LEVEL.val,
ENV.AUTODIST_IS_TESTING.name: ENV.AUTODIST_IS_TESTING.val,
ENV.AUTODIST_PATCH_TF.name: ENV.AUTODIST_PATCH_TF.val,
ENV.AUTODIST_INTERNAL_TF.name: ENV.AUTODIST_INTERNAL_TF.val,
ENV.SYS_DATA_PATH.name: ENV.SYS_DATA_PATH.val,
ENV.SYS_RESOURCE_PATH.name: ENV.SYS_RESOURCE_PATH.val,
}
self._replicas.append((replica_host, spawn_replica(replica_host, None, strategy, env)))

def _start_tf_servers(self):
"""Launch TF server actors on each Ray nodes."""
cluster_spec = Cluster._get_default_cluster_spec(self._resource_spec)
cpu_devices = Cluster._get_node_cpu_devices(self._resource_spec)
gpu_devices = Cluster._get_node_gpu_devices(self._resource_spec)

self._servers = []
for job_name, tasks in cluster_spec.items():
for task_index, full_address in enumerate(tasks):
node_ip, _ = full_address.split(':')
# Make sure we spawn one server per Ray node
# Give it all the GPUs on that node
server = TFServer.options(resources={f"node:{node_ip}": 0.01},
num_gpus=len(gpu_devices.get(node_ip, []))).remote()
self._servers.append(server)
server.launch.remote(cluster_spec,
job_name,
task_index,
len(cpu_devices[node_ip]))

@staticmethod
def _get_resource_info():
"""Create resource_info from resources available to the Ray cluster."""
resource_info = {}
resource_info["nodes"] = []
chief_address = ray._private.services.get_node_ip_address()
for node in ray.nodes():
node_ip = node["NodeManagerAddress"]
cpu_count = node["Resources"].get("CPU")
gpu_count = node["Resources"].get("GPU")
if not node["Alive"] or (cpu_count is None and gpu_count is None):
continue
node = {"address": node_ip,
"cpus": [0] if cpu_count else [],
"gpus": list(range(int(gpu_count))) if gpu_count else []}
if node_ip == chief_address:
node["chief"] = True
resource_info["nodes"].append(node)
return resource_info

def train(self):
"""Runs one training epoch."""
return dict(zip([replica[0] for replica in self._replicas],
ray.get([replica[1].step.remote() for replica in self._replicas])))

def save(self, checkpoint_dir, checkpoint_prefix):
"""Save a checkpoint with prefix."""
ray.get(self._master.save.remote(checkpoint_dir, checkpoint_prefix))

def restore(self, checkpoint_dir):
"""Restore the latest checkpoint from directory."""
ray.get(self._master.restore.remote(checkpoint_dir))

def shutdown(self):
"""Shutdown all the actors and the training job."""
for server in self._servers:
ray.kill(server)
for replica in self._replicas:
ray.kill(replica[1])