Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
yzygitzh committed Dec 8, 2023
1 parent e393332 commit efb2cf1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
44 changes: 22 additions & 22 deletions superbench/benchmarks/micro_benchmarks/dist_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def add_parser_arguments(self):
super().add_parser_arguments()

         self._parser.add_argument(
-            '--use_cpp_impl',
+            '--use_pytorch',
             action='store_true',
-            required=False,
-            help='Whether to use cpp-based implementation.',
+            required=True,
+            help='Whether to use pytorch implementation. If not, cpp implementation will be used.',
         )
self._parser.add_argument(
'--batch_size',
Expand Down Expand Up @@ -324,18 +324,7 @@ def _preprocess(self):
if not super()._preprocess():
return False

-        if self._args.use_cpp_impl:
-            # Assemble commands if cpp impl path
-            self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
-
-            args = '-m %d -n %d -k %d' % (self._args.input_size, self._args.batch_size, self._args.hidden_size)
-            args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta)
-            args += ' --num_layers %d --num_warmups %d --num_iters %d' % \
-                (self._args.num_layers, self._args.num_warmup, self._args.num_steps)
-            if self._args.use_cuda_graph:
-                args += ' --use_cuda_graph'
-            self._commands = ['%s %s' % (self.__bin_path, args)]
-        else:
+        if self._args.use_pytorch:
# Initialize PyTorch if pytorch impl path
if self._args.distributed_impl != DistributedImpl.DDP:
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
Expand Down Expand Up @@ -365,7 +354,18 @@ def _preprocess(self):
else:
self.__device = torch.device('cpu:{}'.format(self.__local_rank))
self.__cuda_available = False
+        else:
+            # Assemble commands if cpp impl path
+            self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
+
+            args = '-m %d -n %d -k %d' % (self._args.input_size, self._args.batch_size, self._args.hidden_size)
+            args += ' --alpha %g --beta %g' % (self._args.alpha, self._args.beta)
+            args += ' --num_layers %d --num_warmups %d --num_iters %d' % \
+                (self._args.num_layers, self._args.num_warmup, self._args.num_steps)
+            if self._args.use_cuda_graph:
+                args += ' --use_cuda_graph'
+            self._commands = ['%s %s' % (self.__bin_path, args)]
+
return True

def _prepare_model(
Expand Down Expand Up @@ -445,12 +445,7 @@ def _benchmark(self):
Return:
True if _benchmark succeeds.
"""
-        if self._args.use_cpp_impl:
-            # Execute commands if cpp impl path
-            if not super()._benchmark():
-                return False
-            return True
-        else:
+        if self._args.use_pytorch:
# Execute PyTorch model if pytorch impl path
batch_size = self._args.batch_size
input_size = self._args.input_size
Expand Down Expand Up @@ -485,6 +480,11 @@ def _benchmark(self):

# Process data and return
return self._process_data(step_times)
+        else:
+            # Execute commands if cpp impl path
+            if not super()._benchmark():
+                return False
+            return True

def _process_raw_result(self, cmd_idx, raw_output):
"""Function to parse raw results and save the summarized results.
Expand Down Expand Up @@ -528,7 +528,7 @@ def _postprocess(self):
if not super()._postprocess():
return False

-        if not self._args.use_cpp_impl:
+        if self._args.use_pytorch:
try:
torch.distributed.destroy_process_group()
except BaseException as e:
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/micro_benchmarks/test_dist_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_pytorch_dist_inference_normal():
assert (benchmark.type == BenchmarkType.MICRO)

# Check predefined parameters of dist-inference benchmark.
-    assert (benchmark._args.use_cpp_impl is False)
+    assert (benchmark._args.use_pytorch is True)
assert (benchmark._args.batch_size == 64)
assert (benchmark._args.input_size == 1024)
assert (benchmark._args.hidden_size == 1024)
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_pytorch_dist_inference_fake_distributed():
assert (benchmark.type == BenchmarkType.MICRO)

# Check predefined parameters of dist-inference benchmark.
-    assert (benchmark._args.use_cpp_impl is False)
+    assert (benchmark._args.use_pytorch is True)
assert (benchmark._args.batch_size == 64)
assert (benchmark._args.input_size == 1024)
assert (benchmark._args.hidden_size == 1024)
Expand Down

0 comments on commit efb2cf1

Please sign in to comment.