diff --git a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py
index bbbbc0f23..ad6a96596 100644
--- a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py
+++ b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py
@@ -74,8 +74,8 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
     return server.target, device
 
 
-def create_monitored_session(target: tf.train.Server, task_index: int,
-                             checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
+def create_monitored_session(target: tf.train.Server, task_index: int, checkpoint_dir: str, checkpoint_save_secs: int,
+                             scaffold: tf.train.Scaffold, config: tf.ConfigProto=None) -> tf.Session:
     """
     Create a monitored session for the worker
     :param target: the target string for the tf.Session
diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py
index 67c0c8b67..ad71f380c 100644
--- a/rl_coach/architectures/tensorflow_components/savers.py
+++ b/rl_coach/architectures/tensorflow_components/savers.py
@@ -28,7 +28,7 @@ def __init__(self, name):
         # if graph is finalized, savers must have already been added. This happens
         # in the case of a MonitoredSession
         self._variables = tf.global_variables()
-
+        # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
         # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
         self._variables = [v for v in self._variables if '/target' not in v.name]
 
diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index 3c03de834..414940a37 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -583,8 +583,8 @@ def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_on
 
 
 class DistributedTaskParameters(TaskParameters):
-    def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
-                 task_index: int, evaluate_only: int=None, num_tasks: int=None,
+    def __init__(self, framework_type: Frameworks=None, parameters_server_hosts: str=None, worker_hosts: str=None,
+                 job_type: str=None, task_index: int=None, evaluate_only: int=None, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
                  shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
                  checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
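Note: the new `scaffold` argument is handed to TensorFlow's monitored-session machinery. For readers unfamiliar with the TF 1.x API, here is a minimal sketch of what a scaffold-aware `create_monitored_session` presumably does internally (an assumed implementation for illustration, not taken from this diff):

```python
import tensorflow as tf

def create_monitored_session(target, task_index, checkpoint_dir,
                             checkpoint_save_secs, scaffold, config=None):
    # the chief (task_index == 0) initializes variables and writes checkpoints;
    # passing a Scaffold with an explicit Saver lets the caller reuse that Saver later
    return tf.train.MonitoredTrainingSession(
        master=target,
        is_chief=(task_index == 0),
        checkpoint_dir=checkpoint_dir,
        save_checkpoint_secs=checkpoint_save_secs,
        scaffold=scaffold,
        config=config)
```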
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 304cf83e2..f0eb08239 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -34,7 +34,8 @@
 from multiprocessing.managers import BaseManager
 import subprocess
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
-from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
+from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
+    get_base_dir, start_multi_threaded_learning
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
 def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     ckpt_inside_container = "/checkpoint"
 
+    non_dist_task_parameters = TaskParameters(
+        framework_type=args.framework,
+        evaluate_only=args.evaluate,
+        experiment_path=args.experiment_path,
+        seed=args.seed,
+        use_cpu=args.use_cpu,
+        checkpoint_save_secs=args.checkpoint_save_secs,
+        checkpoint_save_dir=args.checkpoint_save_dir,
+        export_onnx_graph=args.export_onnx_graph,
+        apply_stop_condition=args.apply_stop_condition
+    )
 
     memory_backend_params = None
     if args.memory_backend_params:
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
         graph_manager.data_store_params = data_store_params
 
     if args.distributed_coach_run_type == RunType.TRAINER:
+        if not args.distributed_training:
+            task_parameters = non_dist_task_parameters
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            args=args,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
-        task_parameters.checkpoint_restore_dir = ckpt_inside_container
+        non_dist_task_parameters.checkpoint_restore_dir = ckpt_inside_container
 
         data_store = None
         if args.data_store_params:
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
             graph_manager=graph_manager,
             data_store=data_store,
             num_workers=args.num_workers,
-            task_parameters=task_parameters
+            task_parameters=non_dist_task_parameters
         )
 
 
@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
         parser.add_argument('-dc', '--distributed_coach',
                             help="(flag) Use distributed Coach.",
                             action='store_true')
+        parser.add_argument('-dt', '--distributed_training',
+                            help="(flag) Use distributed training with Coach. "
+                                 "Used only with --distributed_coach flag. "
+                                 "Ignored if --distributed_coach flag is not used.",
+                            action='store_true')
         parser.add_argument('-dcp', '--distributed_coach_config_path',
                             help="(string) Path to config file when using distributed rollout workers."
                                  "Only distributed Coach parameters should be provided through this config file."
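Note on the help strings: adjacent Python string literals concatenate with no separator, so each fragment needs an explicit trailing space (the `-dt` help above includes them; the pre-existing `-dcp` help does not). A quick illustration:

```python
# adjacent string literals concatenate verbatim, with no implicit space
help_text = ("(flag) Use distributed training with Coach. "
             "Used only with --distributed_coach flag.")
assert "Coach. Used" in help_text

broken = ("(flag) Use distributed training with Coach."
          "Used only with --distributed_coach flag.")
assert "Coach.Used" in broken  # words run together in the rendered --help output
```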
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)
 
-        task_parameters = TaskParameters(
-            framework_type=args.framework,
-            evaluate_only=args.evaluate,
-            experiment_path=args.experiment_path,
-            seed=args.seed,
-            use_cpu=args.use_cpu,
-            checkpoint_save_secs=args.checkpoint_save_secs,
-            checkpoint_restore_dir=args.checkpoint_restore_dir,
-            checkpoint_save_dir=args.checkpoint_save_dir,
-            export_onnx_graph=args.export_onnx_graph,
-            apply_stop_condition=args.apply_stop_condition
-        )
+        if args.num_workers == 1:
+            task_parameters = TaskParameters(
+                framework_type=args.framework,
+                evaluate_only=args.evaluate,
+                experiment_path=args.experiment_path,
+                seed=args.seed,
+                use_cpu=args.use_cpu,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )
+        else:
+            task_parameters = DistributedTaskParameters(
+                framework_type=args.framework,
+                use_cpu=args.use_cpu,
+                num_training_tasks=args.num_workers,
+                experiment_path=args.experiment_path,
+                checkpoint_save_secs=args.checkpoint_save_secs,
+                checkpoint_restore_dir=args.checkpoint_restore_dir,
+                checkpoint_save_dir=args.checkpoint_save_dir,
+                export_onnx_graph=args.export_onnx_graph,
+                apply_stop_condition=args.apply_stop_condition
+            )
 
         # open dashboard
         if args.open_dashboard:
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
 
         # Single-threaded runs
         if args.num_workers == 1:
-            self.start_single_threaded(task_parameters, graph_manager, args)
+            self.start_single_threaded_learning(task_parameters, graph_manager, args)
         else:
-            self.start_multi_threaded(graph_manager, args)
+            global start_graph
+            start_multi_threaded_learning(start_graph, (graph_manager, task_parameters),
+                                          task_parameters, graph_manager, args)
 
-    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    def start_single_threaded_learning(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
         # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
 
-    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
-        total_tasks = args.num_workers
-        if args.evaluation_worker:
-            total_tasks += 1
-
-        ps_hosts = "localhost:{}".format(get_open_port())
-        worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])
-
-        # Shared memory
-        class CommManager(BaseManager):
-            pass
-        CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
-        comm_manager = CommManager()
-        comm_manager.start()
-        shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
-
-        def start_distributed_task(job_type, task_index, evaluation_worker=False,
-                                   shared_memory_scratchpad=shared_memory_scratchpad):
-            task_parameters = DistributedTaskParameters(
-                framework_type=args.framework,
-                parameters_server_hosts=ps_hosts,
-                worker_hosts=worker_hosts,
-                job_type=job_type,
-                task_index=task_index,
-                evaluate_only=0 if evaluation_worker else None,  # 0 value for evaluation worker as it should run infinitely
-                use_cpu=args.use_cpu,
-                num_tasks=total_tasks,  # training tasks + 1 evaluation task
-                num_training_tasks=args.num_workers,
-                experiment_path=args.experiment_path,
-                shared_memory_scratchpad=shared_memory_scratchpad,
-                seed=args.seed+task_index if args.seed is not None else None,  # each worker gets a different seed
-                checkpoint_save_secs=args.checkpoint_save_secs,
-                checkpoint_restore_dir=args.checkpoint_restore_dir,
-                checkpoint_save_dir=args.checkpoint_save_dir,
-                export_onnx_graph=args.export_onnx_graph,
-                apply_stop_condition=args.apply_stop_condition
-            )
-            # we assume that only the evaluation workers are rendering
-            graph_manager.visualization_parameters.render = args.render and evaluation_worker
-            p = Process(target=start_graph, args=(graph_manager, task_parameters))
-            # p.daemon = True
-            p.start()
-            return p
-
-        # parameter server
-        parameter_server = start_distributed_task("ps", 0)
-
-        # training workers
-        # wait a bit before spawning the non chief workers in order to make sure the session is already created
-        workers = []
-        workers.append(start_distributed_task("worker", 0))
-        time.sleep(2)
-        for task_index in range(1, args.num_workers):
-            workers.append(start_distributed_task("worker", task_index))
-
-        # evaluation worker
-        if args.evaluation_worker or args.render:
-            evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
-
-        # wait for all workers
-        [w.join() for w in workers]
-        if args.evaluation_worker:
-            evaluation_worker.terminate()
-
 
 def main():
     launcher = CoachLauncher()
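For context, the `ps_hosts`/`worker_hosts` strings that `start_multi_threaded_learning` builds (see the utils.py hunk at the end of this diff) map onto a standard TF 1.x between-graph replication setup. Roughly, with hypothetical ports (a sketch, not code from this PR):

```python
import tensorflow as tf

# hypothetical local cluster: one parameter server, two workers,
# mirroring the "localhost:<open port>" strings generated above
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:60000"],
    "worker": ["localhost:60001", "localhost:60002"],
})

# each process starts one server for its own job/task pair; the "ps" task
# then blocks on server.join() while the workers run the training graph
server = tf.train.Server(cluster, job_name="worker", task_index=0)
```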
diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index 589ee5fda..24a891cf7 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -88,16 +88,21 @@ def save_to_store(self):
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
 
+            ckpt_state_filename = CheckpointStateFile.checkpoint_state_filename
+            ckpt_name_prefix = None  # stays None when no checkpoint state file exists yet
             state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
+                ckpt_name_prefix = ckpt_state.name
+
+            if ckpt_state_filename is not None and ckpt_name_prefix is not None:
                 checkpoint_file = None
                 for root, dirs, files in os.walk(self.params.checkpoint_dir):
                     for filename in files:
-                        if filename == CheckpointStateFile.checkpoint_state_filename:
+                        if filename == ckpt_state_filename:
                             checkpoint_file = (root, filename)
                             continue
-                        if filename.startswith(ckpt_state.name):
+                        if filename.startswith(ckpt_name_prefix):
                             abs_name = os.path.abspath(os.path.join(root, filename))
                             rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -131,6 +135,8 @@ def load_from_store(self):
         """
         try:
             state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            ckpt_state_filename = state_file.filename
+            ckpt_state_file_path = state_file.path
 
             # wait until lock is removed
             while True:
@@ -139,7 +145,7 @@ def load_from_store(self):
                 if next(objects, None) is None:
                     try:
                         # fetch checkpoint state file from S3
-                        self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
+                        self.mc.fget_object(self.params.bucket_name, ckpt_state_filename, ckpt_state_file_path)
                     except Exception as e:
                         continue
                     break
@@ -156,10 +162,12 @@ def load_from_store(self):
                     )
                 except Exception as e:
                     pass
+            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            ckpt_state = state_file.read()
+            ckpt_name_prefix = ckpt_state.name if ckpt_state is not None else None
 
-            checkpoint_state = state_file.read()
-            if checkpoint_state is not None:
-                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
+            if ckpt_name_prefix is not None:
+                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=ckpt_name_prefix, recursive=True)
                 for obj in objects:
                     filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
                     if not os.path.exists(filename):
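The lock object written at the top of `save_to_store` is just an empty S3 key that readers poll before fetching; stripped down, the protocol looks like this (a hypothetical helper, assuming the same minio client the class already uses):

```python
import io
from minio import Minio

# hypothetical endpoint and credentials
mc = Minio("localhost:9000", access_key="...", secret_key="...", secure=False)

def locked_upload(bucket, upload_fn, lockfile=".lock"):
    # "acquire": create an empty marker object that readers poll for
    mc.put_object(bucket, lockfile, io.BytesIO(b''), 0)
    try:
        upload_fn()
    finally:
        # "release": delete the marker so readers proceed
        mc.remove_object(bucket, lockfile)
```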
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 0a7b25087..8793054e6 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -226,11 +226,15 @@ def _create_session_tf(self, task_parameters: TaskParameters):
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir
 
+            self.checkpoint_saver = tf.train.Saver()
+            scaffold = tf.train.Scaffold(saver=self.checkpoint_saver)
+
             self.sess = create_monitored_session(target=task_parameters.worker_target,
                                                  task_index=task_parameters.task_index,
                                                  checkpoint_dir=checkpoint_dir,
                                                  checkpoint_save_secs=task_parameters.checkpoint_save_secs,
-                                                 config=config)
+                                                 config=config,
+                                                 scaffold=scaffold)
             # set the session for all the modules
             self.set_session(self.sess)
         else:
@@ -258,9 +262,11 @@ def create_session(self, task_parameters: TaskParameters):
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
 
         # Create parameter saver
-        self.checkpoint_saver = SaverCollection()
-        for level in self.level_managers:
-            self.checkpoint_saver.update(level.collect_savers())
+        if not isinstance(task_parameters, DistributedTaskParameters):
+            self.checkpoint_saver = SaverCollection()
+            for level in self.level_managers:
+                self.checkpoint_saver.update(level.collect_savers())
+
         # restore from checkpoint if given
         self.restore_checkpoint()
@@ -540,8 +546,9 @@ def improve(self):
         count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
         while self.total_steps_counters[RunPhase.TRAIN] < count_end:
             self.train_and_act(self.steps_between_evaluation_periods)
-            if self.evaluate(self.evaluation_steps):
-                break
+            if self.task_parameters.task_index == 0 or self.task_parameters.task_index is None:
+                if self.evaluate(self.evaluation_steps):
+                    break
 
     def restore_checkpoint(self):
         self.verify_graph_was_created()
@@ -599,7 +606,9 @@ def save_checkpoint(self):
         if not isinstance(self.task_parameters, DistributedTaskParameters):
             saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
         else:
-            saved_checkpoint_path = checkpoint_path
+            # FIXME: Explicitly managing Saver inside monitored training session is not recommended.
+            # https://github.com/tensorflow/tensorflow/issues/8425#issuecomment-286927528
+            saved_checkpoint_path = self.checkpoint_saver.save(self.sess._tf_sess(), checkpoint_path)
 
         # this is required in order for agents to save additional information like a DND for example
         [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
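On the FIXME above: `sess._tf_sess()` reaches into `MonitoredSession` internals. The workaround discussed in the linked TensorFlow issue is usually written as an unwrapping loop over the private `_sess` chain, which is equally unofficial but keeps the Saver usage explicit (a sketch that relies on private attributes and may break across TF versions):

```python
def get_raw_session(mon_sess):
    # walk the chain of wrapper sessions (MonitoredSession plus its recovery
    # and coordination wrappers) down to the underlying tf.Session
    session = mon_sess
    while type(session).__name__ != 'Session':
        session = session._sess  # private attribute, unofficial API
    return session

# usage: saver.save(get_raw_session(self.sess), checkpoint_path)
```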
diff --git a/rl_coach/tests/test_dist_coach.py b/rl_coach/tests/test_dist_coach.py
index 8bad5a083..e943ef0dc 100644
--- a/rl_coach/tests/test_dist_coach.py
+++ b/rl_coach/tests/test_dist_coach.py
@@ -73,7 +73,9 @@ def get_tests():
     """
     tests = [
         'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
-        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1'
+        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
+        'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2',
+        'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2'
     ]
     return tests
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index 0ada24908..fcc60e857 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -22,6 +22,7 @@
 from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
 from rl_coach import core_types
 from rl_coach.logger import screen
+from rl_coach.utils import start_multi_threaded_learning
 
 
 def data_store_ckpt_save(data_store):
@@ -30,7 +31,15 @@ def data_store_ckpt_save(data_store):
         time.sleep(10)
 
 
-def training_worker(graph_manager, task_parameters, is_multi_node_test):
+def training_worker(graph_manager, task_parameters, args, is_multi_node_test):
+    if args.distributed_training:
+        start_multi_threaded_learning(train, (graph_manager, task_parameters, is_multi_node_test),
+                                      task_parameters, graph_manager, args)
+    else:
+        train(graph_manager, task_parameters, is_multi_node_test)
+
+
+def train(graph_manager, task_parameters, is_multi_node_test):
     """
     restore a checkpoint then perform rollouts using the restored model
     :param graph_manager: An instance of the graph manager
@@ -40,8 +49,9 @@
     # initialize graph
     graph_manager.create_graph(task_parameters)
 
-    # save randomly initialized graph
-    graph_manager.save_checkpoint()
+    # save randomly initialized graph using one trainer
+    if task_parameters.task_index == 0 or task_parameters.task_index is None:
+        graph_manager.save_checkpoint()
 
     # training loop
     steps = 0
@@ -71,10 +81,12 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
 
         if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
             eval_offset += 1
-            if graph_manager.evaluate(graph_manager.evaluation_steps):
-                break
+            if task_parameters.task_index == 0 or task_parameters.task_index is None:
+                if graph_manager.evaluate(graph_manager.evaluation_steps):
+                    break
 
         if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-            graph_manager.save_checkpoint()
+            if task_parameters.task_index == 0 or task_parameters.task_index is None:
+                graph_manager.save_checkpoint()
         else:
             graph_manager.occasionally_save_checkpoint()
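The `task_index == 0 or task_index is None` test now appears in several places (`improve()`, `train()`): it gates evaluation and checkpointing to the chief worker, with `None` covering single-process runs. A small helper like the following (hypothetical, not part of this PR) would remove the duplication:

```python
def is_chief(task_parameters):
    # None: single-process run, no task index assigned;
    # 0: the chief task in a distributed run
    return task_parameters.task_index == 0 or task_parameters.task_index is None
```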
diff --git a/rl_coach/utils.py b/rl_coach/utils.py
index 452e5d0fd..ffb716fb5 100644
--- a/rl_coach/utils.py
+++ b/rl_coach/utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import argparse
 import importlib
 import importlib.util
 import inspect
@@ -24,7 +25,8 @@
 import threading
 import time
 import traceback
-from multiprocessing import Manager
+from multiprocessing import Manager, Process
+from multiprocessing.managers import BaseManager
 from subprocess import Popen
 from typing import List, Tuple, Union
 
@@ -532,3 +534,58 @@ def cleanup():
 
 def indent_string(string):
     return '\t' + string.replace('\n', '\n\t')
+
+
+def start_multi_threaded_learning(target_func, target_func_args, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+    total_tasks = args.num_workers
+    if args.evaluation_worker:
+        total_tasks += 1
+
+    ps_hosts = "localhost:{}".format(get_open_port())
+    worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])
+
+    # Shared memory
+    class CommManager(BaseManager):
+        pass
+    CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
+    comm_manager = CommManager()
+    comm_manager.start()
+    shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()
+
+    def start_distributed_task(job_type, task_index, evaluation_worker=False,
+                               shared_memory_scratchpad=shared_memory_scratchpad):
+        task_parameters.parameters_server_hosts = ps_hosts
+        task_parameters.worker_hosts = worker_hosts
+        task_parameters.job_type = job_type
+        task_parameters.task_index = task_index
+        task_parameters.evaluate_only = 0 if evaluation_worker else None  # 0 value for evaluation worker as it should run infinitely
+        task_parameters.num_tasks = total_tasks  # training tasks + 1 evaluation task
+        task_parameters.shared_memory_scratchpad = shared_memory_scratchpad
+        task_parameters.seed = args.seed+task_index if args.seed is not None else None  # each worker gets a different seed
+
+        # we assume that only the evaluation workers are rendering
+        graph_manager.visualization_parameters.render = args.render and evaluation_worker
+        p = Process(target=target_func, args=target_func_args)
+        # p.daemon = True
+        p.start()
+        return p
+
+    # parameter server
+    parameter_server = start_distributed_task("ps", 0)
+
+    # training workers
+    # wait a bit before spawning the non chief workers in order to make sure the session is already created
+    workers = []
+    workers.append(start_distributed_task("worker", 0))
+    time.sleep(2)
+    for task_index in range(1, args.num_workers):
+        workers.append(start_distributed_task("worker", task_index))
+
+    # evaluation worker
+    if args.evaluation_worker or args.render:
+        evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
+
+    # wait for all workers
+    [w.join() for w in workers]
+    if args.evaluation_worker:
+        evaluation_worker.terminate()
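The `CommManager`/`SharedMemoryScratchPad` pattern moved into this helper is plain `multiprocessing.managers` usage: registering a class makes the manager's server process serve proxied instances to child processes. A self-contained sketch with a toy scratchpad (hypothetical class mirroring the add/get interface):

```python
from multiprocessing.managers import BaseManager

class ScratchPad:
    def __init__(self):
        self._store = {}

    def add(self, key, value):
        self._store[key] = value

    def get(self, key):
        return self._store.get(key)

class CommManager(BaseManager):
    pass

# expose only the methods child processes may call through the proxy
CommManager.register('ScratchPad', ScratchPad, exposed=['add', 'get'])

if __name__ == '__main__':
    manager = CommManager()
    manager.start()                 # spawns the manager server process
    pad = manager.ScratchPad()      # returns a proxy shareable across Processes
    pad.add('chief_ready', True)
    print(pad.get('chief_ready'))   # -> True
```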