Improve graph transformation performance #35 #43

Closed · wants to merge 8 commits
83 changes: 83 additions & 0 deletions autodist/graph_item.py
@@ -17,6 +17,7 @@
import contextlib
import copy
import functools
from collections import defaultdict
from typing import Union, Callable

from google.protobuf.any_pb2 import Any
@@ -242,6 +243,37 @@ def __init__(self, graph: ops.Graph = None, graph_def: GraphDef = None):
self.info = Info()
self.optimizer, self.optimizer_args, self.optimizer_kwargs = None, None, None


# Optimization for the var_op_name_to_grad query.
# Indicates that the graph has been modified and var_op_name_to_grad_dict is stale;
# only consulted when a synchronizer calls the lookup with optimize=True.
self.updated = True
# Caches the result of the var_op_name_to_grad lookup from the previous call.
self.var_op_name_to_grad_dict = dict()
# Maps each update op to the variables it depends on; used to optimize var_op_name_to_grad.
self.update_op_depend_var = defaultdict(list)

# Whether this graph is in loop-optimize mode for the first time.
self.first_time_loop = True
self.loop_phase = False
self.var_quried = []
self.useful_update_op = []

# How many local replicas this graph is comprised of.
self.num_replica = 0
self.var_op_appear_time = defaultdict(int)

def start_loop_optimize(self):
"""Start a loop of synchronizer applications."""
self.first_time_loop = True
self.loop_phase = True

def end_loop_optimize(self):
"""End a loop of synchronizer applications."""
self.first_time_loop = True
self.loop_phase = False

def get_trainable_variables(self):
"""Get variables that need to be synchronized if doing data parallelism."""
return [op.outputs[0] for op in self.trainable_var_op_to_var]
@@ -319,6 +351,10 @@ def all_update_ops(self):
@property
def var_op_name_to_grad_info(self):
"""A mapping from VarHandleOp name (e.g. "W" not "W:0") to its (grad, var, update_op) tuple."""
# This property is only used by callers unaware of the optimized lookup,
# so while in the loop phase we always recompute the dict.
if not self.updated and not self.loop_phase:
return self.var_op_name_to_grad_dict
expected_var_ops = {var.op: (grad, var) for grad, var in self.grad_target_pairs.items()}
res = {}
for op in self.all_update_ops:
@@ -336,8 +372,55 @@ def var_op_name_to_grad_info(self):
if var_op.name in res:
raise ValueError('A variable cannot correspond to more than one update op for now.')
res[var_op.name] = expected_var_ops[var_op] + (op,)
self.updated = False
self.var_op_name_to_grad_dict = res
return res

@property
def var_op_name_to_grad_info_optimize(self):
"""A mapping from VarHandleOp name (e.g. "W" not "W:0") to its (grad, var, update_op) tuple.
An optimized version that is aware of this method is iteratively used"""
# if the graph has not been rewritten, return old dict instead of generating a new one
if not self.updated:
return self.var_op_name_to_grad_dict
expected_var_ops = {var.op: (grad, var) for grad, var in self.grad_target_pairs.items()}
res = []
# Keep the list of update ops that are still worth revisiting.
if self.first_time_loop:
self.useful_update_op = self.all_update_ops.copy()
# Iterate over a snapshot, since ops may be removed from the list inside the loop.
for op in list(self.useful_update_op):
var_op = op.inputs[op_info.UPDATE_OP_VAR_POS].op
on_trainable_variable = var_op in expected_var_ops
var_scope = var_op.name
update_op_scope = parse_name_scope(op.name)
is_initialization = update_op_scope == var_scope
# TODO: we should not hardcode this scope.
# It is actually coming from the name given to the saver
is_saving = update_op_scope.endswith('save')
# TODO(future): support one variable -> multiple update ops (see AdamWeightDecay optimizer)
if on_trainable_variable and not is_initialization and not is_saving and not self._is_auxiliary(op):
if var_op.name in res:
raise ValueError('A variable cannot correspond to more than one update op for now.')
res.append(var_op.name)
self.var_op_name_to_grad_dict[var_op.name] = expected_var_ops[var_op] + (op,)
# Record which variables this update op depends on; once all of them are removed,
# the op can be dropped from the loop.
if self.first_time_loop:
self.update_op_depend_var[op].append(var_op.name)

# At most one variable has been queried (and rewritten) by the synchronizer since
# the last call; once its update op is seen here, drop it from the useful list.
assert len(self.var_quried) <= 1
if len(self.var_quried) > 0:
if var_op.name == self.var_quried[0]:
self.var_op_appear_time[var_op] += 1
self.var_quried.remove(var_op.name)
self.useful_update_op.remove(op)

# The dict has been recalculated; reset the indicators.
self.updated = False
self.first_time_loop = False
return self.var_op_name_to_grad_dict


def _is_auxiliary(self, update_op: ops.Operation):
"""Check whether a specific update_op is an auxiliary op that should not be considered."""
# Skip the AssignSub in AdamWeightDecay
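The new GraphItem fields above implement a cache-and-invalidate pattern: the expensive var_op_name_to_grad scan is computed once, reused while `updated` is False, and pruned as the synchronizer loop consumes variables. Below is a minimal standalone sketch of that pattern, not the actual AutoDist class; `CachedLookup`, `recompute`, and `keys` are illustrative names.

```python
class CachedLookup:
    """Toy illustration of the caching scheme added to GraphItem."""

    def __init__(self):
        self.updated = True        # the source data changed since the last compute
        self.loop_phase = False    # inside an iterative phase where reuse is safe
        self._cache = {}           # cached result of the expensive lookup
        self._useful_keys = []     # keys still worth revisiting inside the loop

    def start_loop(self, keys):
        """Enter the iterative phase; results may be reused between calls."""
        self.loop_phase = True
        self._useful_keys = list(keys)

    def end_loop(self):
        """Leave the iterative phase and drop the loop bookkeeping."""
        self.loop_phase = False
        self._useful_keys = []

    def lookup(self, recompute):
        """Return the cached mapping, recomputing only when the source was modified."""
        if not self.updated:
            return self._cache
        # Revisit only the keys that are still useful; callers can prune this list
        # (as the PR does with useful_update_op) once a key has been handled.
        for key in list(self._useful_keys):
            self._cache[key] = recompute(key)
        self.updated = False
        return self._cache
```

For example, `c = CachedLookup(); c.start_loop(["w0", "w1"]); c.lookup(len)` computes the mapping once; further `c.lookup(len)` calls return the cached dict until `c.updated` is set back to True, which mirrors how the synchronizers flag graph modifications.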
9 changes: 6 additions & 3 deletions autodist/kernel/graph_transformer.py
@@ -64,7 +64,6 @@ def transform(self):
graph_item, self._strategy.node_config = VariablePartitioner.apply(self._strategy.node_config, graph_item)

visualization_util.log_graph(graph=graph_item.graph, name='1-after-partition')

# Create Synchronizers for each node in the strategy
self._initialize_synchronizers()

@@ -146,8 +145,10 @@ def _in_graph_apply(self, graph_item: GraphItem):
GraphItem
"""
new_graph_item = graph_item
new_graph_item.start_loop_optimize()
for var_name, syncer in self._synchronizers.items():
new_graph_item = syncer.in_graph_apply(new_graph_item, var_name)
new_graph_item = syncer.in_graph_apply(new_graph_item, var_name, optimize=True)
new_graph_item.end_loop_optimize()
return new_graph_item

def _between_graph_apply(self, multi_gpu_graph_item: GraphItem):
@@ -161,8 +162,10 @@ def _between_graph_apply(self, multi_gpu_graph_item: GraphItem):
GraphItem
"""
new_graph_item = multi_gpu_graph_item
new_graph_item.start_loop_optimize()
for var_name, syncer in self._synchronizers.items():
new_graph_item = syncer.between_graph_apply(new_graph_item, var_name)
new_graph_item = syncer.between_graph_apply(new_graph_item, var_name, optimize=True)
new_graph_item.end_loop_optimize()
self._prune_colocation_groups(new_graph_item)
# TODO: make this work
# update_shard_values_for_worker(num_workers, worker_id)
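The transformer now brackets the per-variable synchronizer loop with start_loop_optimize/end_loop_optimize so that each optimize=True call can reuse the cached (grad, var, update_op) mapping. A hedged sketch of that calling pattern with simplified stand-in names (`apply_synchronizers` is not an AutoDist function):

```python
def apply_synchronizers(graph_item, synchronizers):
    """Sketch of the loop-optimize bracket used by GraphTransformer above."""
    graph_item.start_loop_optimize()   # enter loop phase; the next lookup rebuilds the cache
    for var_name, syncer in synchronizers.items():
        # With optimize=True the synchronizer reads var_op_name_to_grad_info_optimize,
        # reusing cached entries, and flags graph_item.updated = True where the graph
        # has been modified so stale entries are refreshed.
        graph_item = syncer.in_graph_apply(graph_item, var_name, optimize=True)
    graph_item.end_loop_optimize()     # leave loop phase and reset the bookkeeping flags
    return graph_item
```

The between-graph pass in _between_graph_apply follows the same bracket, calling syncer.between_graph_apply instead.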
27 changes: 20 additions & 7 deletions autodist/kernel/synchronization/all_reduce_synchronizer.py
@@ -66,7 +66,7 @@ def __init__(self, config: synchronizers_pb2.AllReduceSynchronizer):
self._group = config.group
super().__init__()

def in_graph_apply(self, graph_item, var_name):
def in_graph_apply(self, graph_item, var_name, optimize=False):
"""
Perform in-graph synchronization based on AllReduce and TensorFlow Collective Ops.

@@ -75,6 +75,7 @@ def in_graph_apply(self, graph_item, var_name):
Args:
graph_item (graph_item.GraphItem): the graph_item to be distributed
var_name (str): the corresponded variable name
optimize (bool): True when this method is called iteratively so the cached lookup can be used

Returns:
graph_item.GraphItem: The new graph
@@ -88,7 +89,11 @@ def in_graph_apply(self, graph_item, var_name):

# Throw an error if the variable is sparse
master_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(0))
grad, _, _ = graph_item.var_op_name_to_grad_info[master_op_name]
if optimize:
graph_item.updated = True
grad, _, _ = graph_item.var_op_name_to_grad_info_optimize[master_op_name]
else:
grad, _, _ = graph_item.var_op_name_to_grad_info[master_op_name]
with item.graph.as_default():
self._share_initializer(item, var_op_name, master_replica=0)
if isinstance(grad, ops.IndexedSlices):
@@ -97,7 +102,7 @@
self._collect_dense_gradients(item, var_op_name)
return item

def _collect_dense_gradients(self, graph_item, var_op_name):
def _collect_dense_gradients(self, graph_item, var_op_name, optimize=False):
"""Append collective ops after the gradient is calculated."""
if self.num_replicas * self.num_workers <= 1:
raise ValueError('CollectiveOps requires collective group size > 1')
@@ -115,7 +120,11 @@ def _collect_dense_gradients(self, graph_item, var_op_name):

for i in range(0, self.num_replicas):
op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
if optimize:
graph_item.updated = True
grad, _, _ = graph_item.var_op_name_to_grad_info_optimize[op_name]
else:
grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
# TODO (Tairui): (3) Merge of reduction for performance
grad_consumers = get_consumers(grad.op) # this line must happen before the reduction

@@ -126,7 +135,7 @@
update_consumers(grad_consumers, grad, reduced_grad)
# TODO(Hao): update grad, target pair here or not?

def _collect_sparse_gradients(self, graph_item, var_op_name):
def _collect_sparse_gradients(self, graph_item, var_op_name, optimize=False):
"""Append collective ops after the gradient is calculated."""
if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
raise NotImplementedError('Currently the collective NCCL AllGather is not supported in TensorFlow release.'
@@ -140,7 +149,11 @@ def _collect_sparse_gradients(self, graph_item, var_op_name):
raise ValueError('CollectiveOps requires collective group size > 1')
for i in range(0, self.num_replicas):
op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
if optimize:
graph_item.updated = True
grad, _, _ = graph_item.var_op_name_to_grad_info_optimize[op_name]
else:
grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
# TODO (Tairui): (3) Merge of reduction for performance
indices_c_ops = grad.indices.consumers()
indices_cc_ops = get_control_consumers(grad.indices.op)
@@ -192,6 +205,6 @@ def _share_initializer(self, graph_item, var_op_name, master_replica=0):
init_assign_op._update_input(1, master_init_tensor)

# pylint: disable=no-self-use
def between_graph_apply(self, graph_item, var_name):
def between_graph_apply(self, graph_item, var_name, optimize=False):
"""Allreduce synchronizer will do nothing in between-graph synchronization."""
return graph_item
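Each synchronizer method above repeats the same optimize/non-optimize dispatch around the grad-info lookup. A hypothetical helper, not part of this PR, shown only to summarize that dispatch logic:

```python
def lookup_grad_info(graph_item, op_name, optimize=False):
    """Return the (grad, var, update_op) tuple for op_name, optionally via the cached path."""
    if optimize:
        # Signal that the graph may have been modified so the optimized
        # property refreshes its cached entries before serving them.
        graph_item.updated = True
        return graph_item.var_op_name_to_grad_info_optimize[op_name]
    return graph_item.var_op_name_to_grad_info[op_name]
```

Each call site would then reduce to `grad, _, _ = lookup_grad_info(graph_item, op_name, optimize)`.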
32 changes: 25 additions & 7 deletions autodist/kernel/synchronization/ps_synchronizer.py
@@ -63,13 +63,14 @@ def __init__(self, config: synchronizers_pb2.PSSynchronizer):
self._var_op_to_accum_apply_op = {}
super().__init__()

def in_graph_apply(self, graph_item, var_name):
def in_graph_apply(self, graph_item, var_name, optimize=False):
"""
Apply in-graph ps synchronization.

Args:
graph_item: the old graph item
var_name: the variable name w/o replica prefix
optimize: True when this method is called iteratively so the cached lookup can be used

Returns:
graph_item.GraphItem
@@ -80,11 +81,17 @@ def in_graph_apply(self, graph_item, var_name):
master_replica_index = 0

with item.graph.as_default():
self._prune_control_dependencies(item, var_op_name, master_replica=master_replica_index)
self._prune_control_dependencies(item, var_op_name, master_replica=master_replica_index, optimize=optimize)
self._share_variable(item, var_op_name, master_replica=master_replica_index)
master_var_name = ops.prepend_name_scope(var_name, replica_prefix(master_replica_index))
master_var_op_name = get_op_name(master_var_name)
grad, target, update_op = item.var_op_name_to_grad_info[master_var_op_name]
if optimize:
grad, target, update_op = item.var_op_name_to_grad_info_optimize[master_var_op_name]
# Record the queried variable so the next optimized lookup can drop its update op.
item.var_quried.append(master_var_op_name)
else:
grad, target, update_op = item.var_op_name_to_grad_info[master_var_op_name]
agg_grad = self._aggregate_gradients(item, old_update_op=update_op, old_grad=grad, old_target=target)

# update grad_target_pair and variable info
@@ -208,7 +215,7 @@ def ctrl_consumers(op):
raise RuntimeError("Incorrect old_grad.")
return agg_grad

def _prune_control_dependencies(self, graph_item, var_op_name, master_replica=0):
def _prune_control_dependencies(self, graph_item, var_op_name, master_replica=0, optimize=False):
"""
Prune the control dependencies between the train_op on non-master replica and update op.

@@ -223,7 +230,11 @@ def _prune_control_dependencies(self, graph_item, var_op_name, master_replica=0)
if i == master_replica:
continue
this_var_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
_, _, update_op = graph_item.var_op_name_to_grad_info[this_var_op_name]
if optimize:
graph_item.updated = True
_, _, update_op = graph_item.var_op_name_to_grad_info_optimize[this_var_op_name]
else:
_, _, update_op = graph_item.var_op_name_to_grad_info[this_var_op_name]
source_op = self._get_optimizer_source_op(update_op)
remove_from_control_consumers(get_control_consumers(source_op), source_op)

@@ -245,13 +256,14 @@ def _get_optimizer_source_op(update_op):

_BETWEEN_GRAPH_APPLY_SCOPE = 'autodist-between'.lower()

def between_graph_apply(self, graph_item, var_name):
def between_graph_apply(self, graph_item, var_name, optimize=False):
"""
Apply between-graph synchronization to the target ops in the graph.

Args:
graph_item: The current graph.
var_name: the variable to be synchronized.
optimize: True when this method is called iteratively so the cached lookup can be used

Returns:
graph_item.GraphItem: updated graph item.
@@ -261,7 +273,12 @@ def between_graph_apply(self, graph_item, var_name):
item = graph_item
# here the variable on replica:0 has been shared, so the original var_name won't work
var_op_name = ops.prepend_name_scope(get_op_name(var_name), replica_prefix(0))
gradient, target, update_op = item.var_op_name_to_grad_info[var_op_name]
if optimize:
# Flag the graph as modified so the optimized lookup refreshes its cache.
item.updated = True
gradient, target, update_op = item.var_op_name_to_grad_info_optimize[var_op_name]
# Record the queried variable so the next optimized lookup can drop its update op.
item.var_quried.append(var_op_name)
else:
gradient, target, update_op = item.var_op_name_to_grad_info[var_op_name]
with item.graph.as_default():
proxy = self._create_proxy(item, gradient, target) if self._local_replication else None
if proxy:
@@ -296,6 +313,7 @@ def add_sync_op(self, graph_item, var_update_op, variable_replicator=None):
this_worker_cpu = this_worker_cpu.replace(device_type='CPU', device_index=0)

var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op
is_trainable = var_op in graph_item.trainable_var_op_to_var
source_op = self._get_optimizer_source_op(var_update_op)
cc = get_control_consumers(source_op)
13 changes: 13 additions & 0 deletions examples/benchmark/bert_config.json
@@ -0,0 +1,13 @@
{
Contributor: Remove this file? No need to include it in the repo at this moment.

Contributor: Same with the tfrecord file below.

"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 2,
"vocab_size": 30522
}
Binary file added examples/benchmark/tf_examples.tfrecord
Binary file not shown.