Merge branch 'main' into yutji/hpl-rocblas
yukirora authored Nov 22, 2023
2 parents 8cc7674 + 79089b6 commit 73029c7
Showing 10 changed files with 419 additions and 110 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/cuda-unit-test.yml
@@ -33,7 +33,7 @@ steps:
- script: |
SB_MICRO_PATH=$PWD python3 setup.py test
displayName: Run unit tests
timeoutInMinutes: 30
timeoutInMinutes: 60
- script: |
bash <(curl -s https://codecov.io/bash) -cF cuda-unit-test
displayName: Report coverage results
2 changes: 1 addition & 1 deletion dockerfile/rocm5.0.x.dockerfile
@@ -136,7 +136,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}

ADD third_party third_party
RUN make -C third_party rocm
RUN make -C third_party rocm -o rocm_hipblaslt

ADD . .
RUN python3 -m pip install --upgrade setuptools==65.7 && \
2 changes: 1 addition & 1 deletion dockerfile/rocm5.1.x.dockerfile
@@ -141,7 +141,7 @@ RUN echo PATH="$PATH" > /etc/environment && \
WORKDIR ${SB_HOME}

ADD third_party third_party
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm
RUN make ROCBLAS_BRANCH=release/rocm-rel-5.1 -C third_party rocm -o rocm_hipblaslt

ADD . .
RUN python3 -m pip install --no-cache-dir .[amdworker] && \
4 changes: 4 additions & 0 deletions superbench/benchmarks/micro_benchmarks/__init__.py
@@ -9,7 +9,9 @@

from superbench.benchmarks.micro_benchmarks.computation_communication_overlap import ComputationCommunicationOverlap
from superbench.benchmarks.micro_benchmarks.cublas_function import CublasBenchmark
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import BlasLtBaseBenchmark
from superbench.benchmarks.micro_benchmarks.cublaslt_function import CublasLtBenchmark
from superbench.benchmarks.micro_benchmarks.hipblaslt_function import HipBlasLtBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_gemm_flops_performance import CudaGemmFlopsBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_memory_bw_performance import CudaMemBwBenchmark
from superbench.benchmarks.micro_benchmarks.cuda_nccl_bw_performance import CudaNcclBwBenchmark
@@ -37,6 +39,7 @@
from superbench.benchmarks.micro_benchmarks.directx_gemm_flops_performance import DirectXGPUCoreFlops

__all__ = [
'BlasLtBaseBenchmark',
'ComputationCommunicationOverlap',
'CpuMemBwLatencyBenchmark',
'CpuHplBenchmark',
@@ -49,6 +52,7 @@
'CudnnBenchmark',
'DiskBenchmark',
'DistInference',
'HipBlasLtBenchmark',
'GPCNetBenchmark',
'GemmFlopsBenchmark',
'GpuBurnBenchmark',
141 changes: 141 additions & 0 deletions superbench/benchmarks/micro_benchmarks/blaslt_function_base.py
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the BLASLt GEMM Base Class."""
import itertools

from superbench.common.utils import logger
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


def mrange(start, stop=-1, multiplication_factor=2, symbol='x'):
"""Range constructor with multiplication factor.
Args:
start (int): Start number.
stop (int, optional): Stop number. Defaults to -1.
multiplication_factor (int, optional): Multiplication factor. Defaults to 2.
symbol (str, optional): Symbol. Defaults to 'x' (multiplication).
Yields:
int: number in the range.
"""
if symbol == 'x':
while True:
yield start
start *= multiplication_factor
if start > stop or start == 0 or multiplication_factor < 2:
break
elif symbol == '+':
while True:
yield start
start = start + multiplication_factor
if start > stop or start == 0 or multiplication_factor < 1:
break
else:
raise ValueError(f'Invalid symbol {symbol}.')


def validate_mrange(string):
"""Validate mrange string in format start[[:stop]:multiplication_factor].
Args:
string (str): mrange string.
Returns:
bool: whether the mrange is expected.
"""
nums = string.split(':')
if len(nums) > 3:
return False

if len(nums) < 3:
return all(x.isdigit() for x in nums)
return nums[0].isdigit() and nums[1].isdigit() and (nums[2].lstrip('+').isdigit() or nums[2].lstrip('x').isdigit())


class BlasLtBaseBenchmark(MicroBenchmarkWithInvoke):
"""The BLASLt GEMM Base class."""
def __init__(self, name, parameters=''):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)

def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()

self._parser.add_argument(
'--shapes',
type=str,
nargs='+',
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
)
self._parser.add_argument(
'--batch',
type=str,
default='0',
required=False,
help=(
'Batch size for strided batch GEMM, set 0 to disable.'
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
),
)
self._parser.add_argument(
'--num_warmup',
type=int,
default=20,
required=False,
help='Number of warm up steps.',
)
self._parser.add_argument(
'--num_steps',
type=int,
default=50,
required=False,
help='Number of steps to measure.',
)

def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
Return:
True if _preprocess() succeed.
"""
if not super()._preprocess():
return False

if not validate_mrange(self._args.batch):
logger.error(f'Invalid batch size {self._args.batch}.')
return False

for _in_type in self._args.in_types:
if _in_type not in self._in_types:
logger.error(f'Invalid input type {_in_type}.')
return False

self._shapes_to_run = []
for _in_type in self._args.in_types:
for _b in mrange(*map(int, self._args.batch.split(':'))):
for shape in self._args.shapes:
shape_list = shape.replace(',', ' ').split()
if len(shape_list) != 3 or not all(validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape {shape}.')
return False
for _m, _n, _k in itertools.product(
*map(
lambda shape: mrange(
*map(lambda dim: int(dim.lstrip('+').lstrip('x')), shape.split(':')),
symbol=shape.split(':')[2][0]
if len(shape.split(':')) == 3 and any([i in shape for i in ['+', 'x']]) else 'x'
), shape_list
)
):
self._shapes_to_run.append((_m, _n, _k, _b, _in_type))

return True
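
For reference, a minimal usage sketch (not part of this commit) of the mrange and validate_mrange helpers added above; the expected values are traced from the code as written:

# Hedged example of how the --shapes/--batch range strings expand.
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import mrange, validate_mrange

list(mrange(16, 128, 2))              # [16, 32, 64, 128]  multiplicative range ('x')
list(mrange(16, 64, 16, symbol='+'))  # [16, 32, 48, 64]   additive range ('+')
list(mrange(2048))                    # [2048]             no stop given, single value

validate_mrange('4096')        # True  - plain number
validate_mrange('16:128:2')    # True  - start:stop:multiplication_factor
validate_mrange('16:128:+16')  # True  - '+' prefix selects the additive form
validate_mrange('16:128:2:4')  # False - too many fields
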
93 changes: 7 additions & 86 deletions superbench/benchmarks/micro_benchmarks/cublaslt_function.py
@@ -4,14 +4,13 @@
"""Module of the cuBLASLt GEMM benchmark."""

import os
import itertools

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark


class CublasLtBenchmark(MicroBenchmarkWithInvoke):
class CublasLtBenchmark(BlasLtBaseBenchmark):
"""The cuBLASLt GEMM benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
@@ -25,72 +24,10 @@ def __init__(self, parameters=''):
self._bin_name = 'cublaslt_gemm'
self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8']

def mrange(self, start, stop=-1, multiplication_factor=2):
"""Range constructor with multiplication factor.
Args:
start (int): Start number.
stop (int, optional): Stop number. Defaults to -1.
multiplication_factor (int, optional): Multiplication factor. Defaults to 2.
Yields:
int: number in the range.
"""
while True:
yield start
start *= multiplication_factor
if start > stop or start == 0 or multiplication_factor < 2:
break

def validate_mrange(self, string):
"""Validate mrange string in format start[[:stop]:multiplication_factor].
Args:
string (str): mrange string.
Returns:
bool: whether the mrange is expected.
"""
nums = string.split(':')
if len(nums) > 3:
return False
return bool(all(x.isdigit() for x in nums))

def add_parser_arguments(self):
"""Add the specified arguments."""
super().add_parser_arguments()

self._parser.add_argument(
'--shapes',
type=str,
nargs='+',
default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
)
self._parser.add_argument(
'--batch',
type=str,
default='0',
required=False,
help=(
'Batch size for strided batch GEMM, set 0 to disable.'
' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
),
)
self._parser.add_argument(
'--num_warmup',
type=int,
default=20,
required=False,
help='Number of warm up steps.',
)
self._parser.add_argument(
'--num_steps',
type=int,
default=50,
required=False,
help='Number of steps to measure.',
)
self._parser.add_argument(
'--in_types',
type=str,
@@ -111,28 +48,12 @@ def _preprocess(self):

self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

if not self.validate_mrange(self._args.batch):
logger.error(f'Invalid batch size {self._args.batch}.')
return False

self._commands = []
for _in_type in self._args.in_types:
if _in_type not in self._in_types:
logger.error(f'Invalid input type {_in_type}.')
return False
for _b in self.mrange(*map(int, self._args.batch.split(':'))):
for shape in self._args.shapes:
shape_list = shape.replace(',', ' ').split()
if len(shape_list) != 3 or not all(self.validate_mrange(x) for x in shape_list):
logger.error(f'Invalid shape {shape}.')
return False
for _m, _n, _k in itertools.product(
*map(lambda shape: self.mrange(*map(int, shape.split(':'))), shape_list)
):
self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
)
for _m, _n, _k, _b, _in_type in self._shapes_to_run:
self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
)

return True
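
As a sanity check, the simplified loop above should emit the same command lines as the code it replaces. Below is a hedged sketch of that expansion (the binary path is hypothetical; the -w/-i values are the --num_warmup/--num_steps defaults from BlasLtBaseBenchmark):

bin_path = '/opt/superbench/bin/cublaslt_gemm'  # hypothetical install location of the built binary
# (m, n, k, batch, in_type) tuples as produced by BlasLtBaseBenchmark._preprocess
shapes_to_run = [(2048, 2048, 2048, 0, 'fp16'), (4096, 4096, 4096, 0, 'fp16')]

commands = [
    f'{bin_path} -m {m} -n {n} -k {k} -b {b} -w 20 -i 50 -t {t}'
    for m, n, k, b, t in shapes_to_run
]
# e.g. '/opt/superbench/bin/cublaslt_gemm -m 2048 -n 2048 -k 2048 -b 0 -w 20 -i 50 -t fp16'
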
