Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmarks: Micro benchmark - Add hipBLASLt function benchmark #576

Merged
merged 7 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dockerfile/rocm5.0.x.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}

ADD third_party third_party
RUN make -C third_party rocm
RUN make -C third_party rocm -o rocm_hipblaslt

ADD . .
RUN python3 -m pip install --upgrade setuptools==65.7 && \
Expand Down
2 changes: 1 addition & 1 deletion dockerfile/rocm5.1.x.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}

ADD third_party third_party
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm -o rocm_hipblaslt

ADD . .
RUN python3 -m pip install --no-cache-dir .[amdworker] && \
Expand Down
4 changes: 4 additions & 0 deletions superbench/benchmarks/micro_benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import BlasLtBaseBenchmark
from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark
from superbench.benchmarks.micro_benchmarks.hipblas_function import HipBlasLtBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
Expand Down Expand Up @@ -37,6 +39,7 @@
from superbench.benchmarks.micro_benchmarks.directx_gemm_flops_performance import DirectXGPUCoreFlops

__all__ = [
'BlasLtBaseBenchmark',
'ComputationCommunicationOverlap',
'CpuMemBwLatencyBenchmark',
'CpuHplBenchmark',
Expand All @@ -49,6 +52,7 @@
'CudnnBenchmark',
'DiskBenchmark',
'DistInference',
'HipBlasLtBenchmark',
'GPCNetBenchmark',
'GemmFlopsBenchmark',
'GpuBurnBenchmark',
Expand Down
141 changes: 141 additions & 0 deletions superbench/benchmarks/micro_benchmarks/blaslt_function_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the BLASLt GEMM Base Class."""
import itertools

from superbench.common.utils import logger
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


def mrange(start, stop=-1, multiplication_factor=2, symbol='x'):
"""Range constructor with multiplication factor.

Args:
start (int): Start number.
stop (int, optional): Stop number. Defaults to -1.
multiplication_factor (int, optional): Multiplication factor. Defaults to 2.
symbol (str, optional): Symbol. Defaults to 'x' (multiplication).

Yields:
int: number in the range.
"""
if symbol == 'x':
while True:
yield start
start *= multiplication_factor
if start > stop or start == 0 or multiplication_factor < 2:
break
elif symbol == '+':
while True:
yield start
start = start + multiplication_factor
if start > stop or start == 0 or multiplication_factor < 1:
break
else:
raise ValueError(f'Invalid symbol {symbol}.')


def validate_mrange(string):
"""Validate mrange string in format start[[:stop]:multiplication_factor].

Args:
string (str): mrange string.

Returns:
bool: whether the mrange is expected.
"""
nums = string.split(':')
if len(nums) > 3:
return False

if len(nums) < 3:
return all(x.isdigit() for x in nums)
return nums[0].isdigit() and nums[1].isdigit() and (nums[2].lstrip('+').isdigit() or nums[2].lstrip('x').isdigit())


class BlasLtBaseBenchmark(MicroBenchmarkWithInvoke):
"""The BLASLt GEMM Base class."""
def __init__(self, name, parameters=''):
"""Constructor.

Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)

def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()

self._parser.add_argument(
'--shapes',
type=str,
nargs='+',
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
)
self._parser.add_argument(
'--batch',
type=str,
default='0',
required=False,
help=(
'Batch size for strided batch GEMM, set 0 to disable.'
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
),
)
self._parser.add_argument(
'--num_warmup',
type=int,
default=20,
required=False,
help='Number of warm up steps.',
)
self._parser.add_argument(
'--num_steps',
type=int,
default=50,
required=False,
help='Number of steps to measure.',
)

def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.

Return:
True if _preprocess() succeed.
"""
if not super()._preprocess():
return False

if not validate_mrange(self._args.batch):
logger.error(f'Invalid batch size {self._args.batch}.')
return False

for _in_type in self._args.in_types:
if _in_type not in self._in_types:
logger.error(f'Invalid input type {_in_type}.')
return False

self._shapes_to_run = []
for _in_type in self._args.in_types:
for _b in mrange(*map(int, self._args.batch.split(':'))):
for shape in self._args.shapes:
shape_list = shape.replace(',', ' ').split()
if len(shape_list) != 3 or not all(validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape {shape}.')
return False
for _m, _n, _k in itertools.product(
*map(
lambda shape: mrange(
*map(lambda dim: int(dim.lstrip('+').lstrip('x')), shape.split(':')),
symbol=shape.split(':')[2][0]
if len(shape.split(':')) == 3 and any([i in shape for i in ['+', 'x']]) else 'x'
), shape_list
)
):
self._shapes_to_run.append((_m, _n, _k, _b, _in_type))

return True
93 changes: 7 additions & 86 deletions superbench/benchmarks/micro_benchmarks/cublaslt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
"""Module of the cuBLASLt GEMM benchmark."""

import os
import itertools

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark


class CublasLtBenchmark(MicroBenchmarkWithInvoke):
class CublasLtBenchmark(BlasLtBaseBenchmark):
"""The cuBLASLt GEMM benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
Expand All @@ -25,72 +24,10 @@ def __init__(self, name, parameters=''):
self._bin_name = 'cublaslt_gemm'
self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8']

def mrange(self, start, stop=-1, multiplication_factor=2):
"""Range constructor with multiplication factor.

Args:
start (int): Start number.
stop (int, optional): Stop number. Defaults to -1.
multiplication_factor (int, optional): Multiplication factor. Defaults to 2.

Yields:
int: number in the range.
"""
while True:
yield start
start *= multiplication_factor
if start > stop or start == 0 or multiplication_factor < 2:
break

def validate_mrange(self, string):
"""Validate mrange string in format start[[:stop]:multiplication_factor].

Args:
string (str): mrange string.

Returns:
bool: whether the mrange is expected.
"""
nums = string.split(':')
if len(nums) > 3:
return False
return bool(all(x.isdigit() for x in nums))

def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()

self._parser.add_argument(
'--shapes',
type=str,
nargs='+',
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
)
self._parser.add_argument(
'--batch',
type=str,
default='0',
required=False,
help=(
'Batch size for strided batch GEMM, set 0 to disable.'
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
),
)
self._parser.add_argument(
'--num_warmup',
type=int,
default=20,
required=False,
help='Number of warm up steps.',
)
self._parser.add_argument(
'--num_steps',
type=int,
default=50,
required=False,
help='Number of steps to measure.',
)
self._parser.add_argument(
'--in_types',
type=str,
Expand All @@ -111,28 +48,12 @@ def _preprocess(self):

self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

if not self.validate_mrange(self._args.batch):
logger.error(f'Invalid batch size {self._args.batch}.')
return False

self._commands = []
for _in_type in self._args.in_types:
if _in_type not in self._in_types:
logger.error(f'Invalid input type {_in_type}.')
return False
for _b in self.mrange(*map(int, self._args.batch.split(':'))):
for shape in self._args.shapes:
shape_list = shape.replace(',', ' ').split()
if len(shape_list) != 3 or not all(self.validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape {shape}.')
return False
for _m, _n, _k in itertools.product(
*map(lambda shape: self.mrange(*map(int, shape.split(':'))), shape_list)
):
self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
)
for _m, _n, _k, _b, _in_type in self._shapes_to_run:
self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
)

return True

Expand Down
Loading
Loading