diff --git a/.github/workflows/build-image.yml b/.github/workflows/build-image.yml
index 94d53d9a5..ed821d654 100644
--- a/.github/workflows/build-image.yml
+++ b/.github/workflows/build-image.yml
@@ -29,18 +29,22 @@ jobs:
             dockerfile: cuda12.2
             tags: superbench/main:cuda12.2
             runner: [self-hosted, rocm-build]
+            build_args: "NUM_MAKE_JOBS=64"
           - name: cuda11.1.1
             dockerfile: cuda11.1.1
             tags: superbench/main:cuda11.1.1,superbench/superbench:latest
             runner: ubuntu-latest
+            build_args: "NUM_MAKE_JOBS=8"
           - name: rocm5.7
            dockerfile: rocm5.7.x
             tags: superbench/main:rocm5.7
             runner: [self-hosted, rocm-build]
+            build_args: "NUM_MAKE_JOBS=64"
           - name: rocm6.0
             dockerfile: rocm6.0.x
             tags: superbench/main:rocm6.0
             runner: [self-hosted, rocm-build]
+            build_args: "NUM_MAKE_JOBS=64"
     steps:
       - name: Checkout
         uses: actions/checkout@v2
@@ -80,7 +84,7 @@ jobs:
           fi

           DOCKERFILE=dockerfile/${{ matrix.dockerfile }}.dockerfile
-          BUILD_ARGS="NUM_MAKE_JOBS=8"
+          BUILD_ARGS=${{ matrix.build_args }}
          if [[ "${{ matrix.extra_args }}" ]]; then
            BUILD_ARGS="${BUILD_ARGS} ${{ matrix.extra_args }}"
           fi
diff --git a/dockerfile/rocm5.7.x.dockerfile b/dockerfile/rocm5.7.x.dockerfile
index ee762e9ee..ce87e9fc6 100644
--- a/dockerfile/rocm5.7.x.dockerfile
+++ b/dockerfile/rocm5.7.x.dockerfile
@@ -17,6 +17,7 @@ RUN apt-get update && \
     apt-get -q install -y --no-install-recommends \
     autoconf \
     automake \
+    bc \
     build-essential \
     curl \
     dmidecode \
@@ -27,6 +28,7 @@ RUN apt-get update && \
     libaio-dev \
     libboost-program-options-dev \
     libcap2 \
+    libcurl4-openssl-dev \
     libnuma-dev \
     libpci-dev \
     libssl-dev \
@@ -38,6 +40,7 @@ RUN apt-get update && \
     openssh-client \
     openssh-server \
     pciutils \
+    python3-mpi4py \
     rsync \
     sudo \
     util-linux \
@@ -46,11 +49,11 @@ RUN apt-get update && \
     && \
     rm -rf /tmp/*

-ARG NUM_MAKE_JOBS=16
+ARG NUM_MAKE_JOBS=

 # Check if CMake is installed and its version
 RUN cmake_version=$(cmake --version 2>/dev/null | grep -oP "(?<=cmake version )(\d+\.\d+)" || echo "0.0") && \
-    required_version="3.26.4" && \
+    required_version="3.24.1" && \
     if [ "$(printf "%s\n" "$required_version" "$cmake_version" | sort -V | head -n 1)" != "$required_version" ]; then \
     echo "existing cmake version is ${cmake_version}" && \
     cd /tmp && \
@@ -100,21 +103,9 @@ RUN if ! command -v ofed_info >/dev/null 2>&1; then \
     rm -rf MLNX_OFED_LINUX-${OFED_VERSION}* ; \
     fi

-# Install UCX
-ENV UCX_VERSION=1.14.1
-RUN if [ -z "$(ls -A /opt/ucx)" ]; then \
-    echo "/opt/ucx is empty. Installing UCX..."; \
-    cd /tmp && \
-    git clone https://github.com/openucx/ucx.git -b v${UCX_VERSION} && \
-    cd ucx && \
-    ./autogen.sh && \
-    mkdir build && \
-    cd build && \
-    ../configure -prefix=$UCX_DIR --with-rocm=/opt/rocm --without-knem && \
-    make -j $(nproc) && make -j $(nproc) install && rm -rf /tmp/ucx-${UCX_VERSION} ; \
-    else \
-    echo "/opt/ucx is not empty. Skipping UCX installation."; \
-    fi
+# Add target file to help determine which device(s) to build for
+ENV ROCM_PATH=/opt/rocm
+RUN bash -c 'echo -e "gfx90a:xnack-\ngfx90a:xnack+\ngfx940\ngfx941\ngfx942\ngfx1030\ngfx1100\ngfx1101\ngfx1102\n" >> ${ROCM_PATH}/bin/target.lst'

 # Install OpenMPI
 ENV OPENMPI_VERSION=4.1.x
@@ -127,7 +118,7 @@ RUN [ -d /usr/local/bin/mpirun ] || { \
     ./autogen.pl && \
     mkdir build && \
     cd build && \
-    ../configure --prefix=/usr/local --enable-orterun-prefix-by-default --enable-mpirun-prefix-by-default --enable-prte-prefix-by-default --enable-mca-no-build=btl-uct --with-ucx=/opt/ucx --with-rocm=/opt/rocm && \
+    ../configure --prefix=/usr/local --enable-orterun-prefix-by-default --enable-mpirun-prefix-by-default --enable-prte-prefix-by-default --with-rocm=/opt/rocm && \
     make -j $(nproc) && \
     make -j $(nproc) install && \
     ldconfig && \
@@ -148,12 +139,18 @@ RUN cd /opt/ && \
     cd rccl && \
     mkdir build && \
     cd build && \
-    CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_PREFIX_PATH=/opt/rocm/ .. && \
+    CXX=/opt/rocm/bin/hipcc cmake -DHIP_COMPILER=clang -DCMAKE_BUILD_TYPE=Release -DCMAKE_VERBOSE_MAKEFILE=1 \
+        -DCMAKE_PREFIX_PATH="${ROCM_PATH}/hsa;${ROCM_PATH}/hip;${ROCM_PATH}/share/rocm/cmake/;${ROCM_PATH}" \
+        .. && \
     make -j${NUM_MAKE_JOBS}

+# Install AMD SMI Python Library
+RUN cd /opt/rocm/share/amd_smi && \
+    python3 -m pip install --user .
+
 ENV PATH="/opt/superbench/bin:/usr/local/bin/:/opt/rocm/hip/bin/:/opt/rocm/bin/:${PATH}" \
     LD_PRELOAD="/opt/rccl/build/librccl.so:$LD_PRELOAD" \
-    LD_LIBRARY_PATH="/opt/ucx/lib:/usr/local/lib/:/opt/rocm/lib:${LD_LIBRARY_PATH}" \
+    LD_LIBRARY_PATH="/usr/local/lib/:/opt/rocm/lib:${LD_LIBRARY_PATH}" \
     SB_HOME=/opt/superbench \
     SB_MICRO_PATH=/opt/superbench \
     ANSIBLE_DEPRECATION_WARNINGS=FALSE \
@@ -163,13 +160,17 @@ RUN echo PATH="$PATH" > /etc/environment && \
     echo LD_LIBRARY_PATH="$LD_LIBRARY_PATH" >> /etc/environment && \
     echo SB_MICRO_PATH="$SB_MICRO_PATH" >> /etc/environment

+RUN apt install rocm-cmake -y && \
+    python3 -m pip install --upgrade pip wheel setuptools==65.7
+
 WORKDIR ${SB_HOME}

+ADD third_party third_party
+RUN make RCCL_HOME=/opt/rccl/build/ MPI_HOME=/usr/local ROCBLAS_BRANCH=release/rocm-rel-5.7.1.1 HIPBLASLT_BRANCH=release-staging/rocm-rel-5.7 ROCM_VER=rocm-5.5.0 -C third_party rocm -o cpu_hpl -o cpu_stream -o megatron_lm
+
 ADD . .
-RUN apt install rocm-cmake -y && \
-    python3 -m pip install --upgrade pip wheel setuptools==65.7 && \
-    python3 -m pip install .[amdworker] && \
+#ENV USE_HIPBLASLT_DATATYPE=1
+ENV CXX=/opt/rocm/bin/hipcc
+RUN python3 -m pip install .[amdworker] && \
+    make cppbuild && \
     make postinstall
-RUN make cppbuild
-ADD third_party third_party
-RUN make RCCL_HOME=/opt/rccl/build/ ROCBLAS_BRANCH=release/rocm-rel-5.7.1.1 HIPBLASLT_BRANCH=release-staging/rocm-rel-5.7 ROCM_VER=rocm-5.5.0 -C third_party rocm -o cpu_hpl -o cpu_stream -o megatron_lm
diff --git a/docs/user-tutorial/benchmarks/micro-benchmarks.md b/docs/user-tutorial/benchmarks/micro-benchmarks.md
index 5155be7b0..388bfa119 100644
--- a/docs/user-tutorial/benchmarks/micro-benchmarks.md
+++ b/docs/user-tutorial/benchmarks/micro-benchmarks.md
@@ -58,17 +58,18 @@ Large scale matmul operation using `torch.matmul` with one GPU.
 |--------------------------------|-----------|--------------------------------|
 | pytorch-matmul/nosharding_time | time (ms) | Time of pure matmul operation. |

-### `cublaslt-gemm`
+### `cublaslt-gemm` / `hipblaslt-gemm`

 #### Introduction

-Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/cublas/#cublasltmatmul).
+Measure the GEMM performance of [`cublasLtMatmul`](https://docs.nvidia.com/cuda/cublas/#cublasltmatmul) or [`hipblasLt-bench`](https://github.com/ROCm/hipBLASLt/blob/develop/clients/benchmarks/README.md).

 #### Metrics

-| Name                                                      | Unit           | Description                     |
-|----------------------------------------------------------|----------------|---------------------------------|
-| cublaslt-gemm/${dtype}\_${batch}\_${m}\_${n}\_${k}_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
+| Name                                                       | Unit           | Description                     |
+|-----------------------------------------------------------|----------------|---------------------------------|
+| cublaslt-gemm/${dtype}\_${batch}\_${m}\_${n}\_${k}_flops  | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |
+| hipblaslt-gemm/${dtype}\_${batch}\_${m}\_${n}\_${k}_flops | FLOPS (TFLOPS) | TFLOPS of measured GEMM kernel. |

 ### `cublas-function`
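For the unit column above: both benchmarks report TFLOPS. A quick sanity check in Python (illustrative only; the 1x896x896x896 shape and the 58.6245 TFLOPS value are taken from the hipblaslt unit-test fixture updated later in this patch, and 2*m*n*k is the standard GEMM FLOP count):

```python
# Illustrative sanity check of the TFLOPS unit used by the {cublaslt,hipblaslt}-gemm metrics.
m = n = k = 896
batch = 1
gemm_flops = 2 * m * n * k * batch       # standard GEMM FLOP count (one multiply + one add per term)
reported_tflops = 58.6245                # example value from the updated unit test in this patch
time_us = gemm_flops / (reported_tflops * 1e12) * 1e6
print(f"{gemm_flops:.3e} FLOPs -> {time_us:.2f} us per GEMM at {reported_tflops} TFLOPS")
```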
diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
index f29149e50..27a220f90 100644
--- a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
+++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
@@ -33,6 +33,11 @@ else()
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
     if(DEFINED ENV{USE_HIPBLASLT_DATATYPE})
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1")
+    elseif(DEFINED ENV{USE_HIP_DATATYPE})
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIP_DATATYPE=1")
+    endif()
+    if(DEFINED ENV{USE_HIPBLAS_COMPUTETYPE})
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLAS_COMPUTETYPE=1")
     endif()
     target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
 else()
diff --git a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
index 1db7e270f..c6d7cd033 100644
--- a/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
+++ b/superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
@@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf;
 #if defined(USE_HIPBLASLT_DATATYPE)
 #define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
 #define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
+#elif defined(USE_HIP_DATATYPE)
+#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F
+#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F
 #else
 #define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
 #define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
 #endif
+#if defined(USE_HIPBLAS_COMPUTETYPE)
+#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F
+#else
+#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
+#endif
 #else
 #include <cublasLt.h>
 #include <cuda_fp16.h>
@@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
     CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
     CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));

-    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
-    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
+    CHECK_CUBLASLT_ERROR(
+        hipblasLtMatmulDescCreate(&matmul1, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));
+    CHECK_CUBLASLT_ERROR(
+        hipblasLtMatmulDescCreate(&matmul2, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));

     hipblasOperation_t trans = HIPBLAS_OP_N;
     CHECK_CUBLASLT_ERROR(
diff --git a/superbench/benchmarks/micro_benchmarks/hipblaslt_function.py b/superbench/benchmarks/micro_benchmarks/hipblaslt_function.py
index 508973777..afe220e68 100644
--- a/superbench/benchmarks/micro_benchmarks/hipblaslt_function.py
+++ b/superbench/benchmarks/micro_benchmarks/hipblaslt_function.py
@@ -103,7 +103,8 @@ def _process_raw_result(self, cmd_idx, raw_output):
                     raise ValueError('Invalid result')
                 self._result.add_result(
-                    f'{self._precision_in_commands[cmd_idx]}_{fields[3]}_{"_".join(fields[4:7])}_flops', float(fields[-2])
+                    f'{self._precision_in_commands[cmd_idx]}_{fields[3]}_{"_".join(fields[4:7])}_flops',
+                    float(fields[-2]) / 1000
                 )
             except BaseException as e:
                 self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
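A minimal sketch of the parsing rule above, with a hypothetical result row; only the field positions the parser actually uses (`fields[3]`, `fields[4:7]`, `fields[-2]`) are meaningful here. hipblaslt-bench prints GFLOPS, and the new `/ 1000` converts to the TFLOPS unit documented earlier:

```python
# Hypothetical hipblaslt-bench row; '?' marks columns the parser never reads.
precision = 'fp16'
fields = ['?', '?', '?', '1', '896', '896', '896', '?', '58624.5', '?']
metric = f'{precision}_{fields[3]}_{"_".join(fields[4:7])}_flops'
value = float(fields[-2]) / 1000    # GFLOPS from hipblaslt-bench -> TFLOPS
assert metric == 'fp16_1_896_896_896_flops'
assert value == 58.6245             # matches the updated unit test below
```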
diff --git a/superbench/benchmarks/micro_benchmarks/rocm_common.cmake b/superbench/benchmarks/micro_benchmarks/rocm_common.cmake
index be60df127..1d2cc3934 100644
--- a/superbench/benchmarks/micro_benchmarks/rocm_common.cmake
+++ b/superbench/benchmarks/micro_benchmarks/rocm_common.cmake
@@ -45,8 +45,7 @@ message(STATUS "CMAKE HIP ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
 if(EXISTS ${HIP_PATH})
     # Search for hip in common locations
-    list(APPEND CMAKE_PREFIX_PATH ${HIP_PATH} ${ROCM_PATH})
-    set(CMAKE_PREFIX_PATH /opt/rocm ROCM_PATH)
+    list(APPEND CMAKE_PREFIX_PATH ${HIP_PATH} ${ROCM_PATH} ${ROCM_PATH}/hsa ${ROCM_PATH}/hip ${ROCM_PATH}/share/rocm/cmake/)
     set(CMAKE_CXX_COMPILER "${HIP_PATH}/bin/hipcc")
     set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
     set(CMAKE_MODULE_PATH "${HIP_PATH}/lib/cmake/hip" ${CMAKE_MODULE_PATH})
diff --git a/superbench/common/utils/device_manager.py b/superbench/common/utils/device_manager.py
index 09398cac0..18bed8dce 100644
--- a/superbench/common/utils/device_manager.py
+++ b/superbench/common/utils/device_manager.py
@@ -13,7 +13,7 @@
 if gpu.vendor == 'nvidia' or gpu.vendor == 'nvidia-graphics':
     import py3nvml.py3nvml as nvml
 elif gpu.vendor == 'amd' or gpu.vendor == 'amd-graphics':
-    from pyrsmi import rocml
+    import amdsmi as rocml


 class DeviceManager:
@@ -150,7 +150,7 @@ def get_device_compute_capability(self):
         try:
             cap = nvml.nvmlDeviceGetCudaComputeCapability(self._device_handlers[0])
         except Exception as err:
-            logger.error('Get device compute capability failed: {}'.format(str(err)))
+            logger.warning('Get device compute capability failed: {}'.format(str(err)))
             return None
         return cap
@@ -166,7 +166,7 @@ def get_device_utilization(self, idx):
         try:
             util = nvml.nvmlDeviceGetUtilizationRates(self._device_handlers[idx])
         except Exception as err:
-            logger.error('Get device utilization failed: {}'.format(str(err)))
+            logger.warning('Get device utilization failed: {}'.format(str(err)))
             return None
         return util.gpu
@@ -182,7 +182,7 @@ def get_device_temperature(self, idx):
         try:
             temp = nvml.nvmlDeviceGetTemperature(self._device_handlers[idx], nvml.NVML_TEMPERATURE_GPU)
         except Exception as err:
-            logger.error('Get device temperature failed: {}'.format(str(err)))
+            logger.warning('Get device temperature failed: {}'.format(str(err)))
             temp = None
         return temp
@@ -198,7 +198,7 @@ def get_device_power(self, idx):
         try:
             power = nvml.nvmlDeviceGetPowerUsage(self._device_handlers[idx])
         except Exception as err:
-            logger.error('Get device power failed: {}'.format(str(err)))
+            logger.warning('Get device power failed: {}'.format(str(err)))
             return None
         return int(int(power) / 1000)
@@ -214,7 +214,7 @@ def get_device_power_limit(self, idx):
         try:
             powerlimit = nvml.nvmlDeviceGetPowerManagementLimit(self._device_handlers[idx])
         except Exception as err:
-            logger.error('Get device power limitation failed: {}'.format(str(err)))
+            logger.warning('Get device power limitation failed: {}'.format(str(err)))
             return None
         return int(int(powerlimit) / 1000)
@@ -231,7 +231,7 @@ def get_device_memory(self, idx):
         try:
             mem = nvml.nvmlDeviceGetMemoryInfo(self._device_handlers[idx])
         except Exception as err:
-            logger.error('Get device memory failed: {}'.format(str(err)))
+            logger.warning('Get device memory failed: {}'.format(str(err)))
             return None, None
         return mem.used, mem.total
@@ -304,7 +304,7 @@ def get_device_ecc_error(self, idx):
         except nvml.NVMLError:
             pass
         except Exception as err:
-            logger.error('Get device ECC information failed: {}'.format(str(err)))
+            logger.warning('Get device ECC information failed: {}'.format(str(err)))
             return None, None

         try:
@@ -316,7 +316,7 @@ def get_device_ecc_error(self, idx):
         except nvml.NVMLError:
             pass
         except Exception as err:
-            logger.error('Get device ECC information failed: {}'.format(str(err)))
+            logger.warning('Get device ECC information failed: {}'.format(str(err)))
             return None, None

         return corrected_ecc, uncorrected_ecc
@@ -326,12 +326,13 @@ class AmdDeviceManager(DeviceManager):
     """Device management module for AMD."""
     def __init__(self):
         """Constructor."""
-        rocml.smi_initialize()
+        rocml.amdsmi_init()
+        self._device_handlers = rocml.amdsmi_get_processor_handles()
         super().__init__()

     def __del__(self):
         """Destructor."""
-        rocml.smi_shutdown()
+        rocml.amdsmi_shut_down()

     def get_device_count(self):
@@ -339,7 +340,7 @@ def get_device_count(self):
         Return:
             count (int): count of device.
         """
-        return rocml.smi_get_device_count()
+        return len(self._device_handlers)

     def get_device_utilization(self, idx):
         """Get the utilization of device.
@@ -351,11 +352,11 @@ def get_device_utilization(self, idx):
             util (int): the utilization of device, None means failed to get the data.
         """
         try:
-            util = rocml.smi_get_device_utilization(idx)
+            engine_usage = rocml.amdsmi_get_gpu_activity(self._device_handlers[idx])
         except Exception as err:
-            logger.error('Get device utilization failed: {}'.format(str(err)))
+            logger.warning('Get device utilization failed: {}'.format(str(err)))
             return None
-        return util
+        return engine_usage['gfx_activity']

     def get_device_temperature(self, idx):
         """Get the temperature of device, unit: celsius.
@@ -366,8 +367,16 @@ def get_device_temperature(self, idx):
         Args:
             idx (int): device index.

         Return:
             temp (int): the temperature of device, None means failed to get the data.
         """
-        # Currently no API provided in rocml.
-        return None
+        try:
+            temp = rocml.amdsmi_get_temp_metric(
+                self._device_handlers[idx], rocml.AmdSmiTemperatureType.EDGE, rocml.AmdSmiTemperatureMetric.CURRENT
+            )
+        except (rocml.AmdSmiLibraryException, rocml.AmdSmiParameterException):
+            temp = None
+        except Exception as err:
+            logger.warning('Get device temperature failed: {}'.format(str(err)))
+            temp = None
+        return temp

     def get_device_power(self, idx):
         """Get the realtime power of device, unit: watt.
@@ -379,11 +388,11 @@ def get_device_power(self, idx):
         Args:
             idx (int): device index.

         Return:
             temp (int): the realtime power of device, None means failed to get the data.
""" try: - power = rocml.smi_get_device_average_power(idx) + power_measure = rocml.amdsmi_get_power_info(self._device_handlers[idx]) except Exception as err: - logger.error('Get device power failed: {}'.format(str(err))) + logger.warning('Get device power failed: {}'.format(str(err))) return None - return int(int(power) / 1000) + return int(power_measure['average_socket_power']) def get_device_power_limit(self, idx): """Get the power management limit of device, unit: watt. @@ -394,8 +403,12 @@ def get_device_power_limit(self, idx): Return: temp (int): the power management limit of device, None means failed to get the data. """ - # Currently no API provided in rocml. - return None + try: + power_measure = rocml.amdsmi_get_power_info(self._device_handlers[idx]) + except Exception as err: + logger.warning('Get device power limit failed: {}'.format(str(err))) + return None + return int(power_measure['power_limit']) def get_device_memory(self, idx): """Get the memory information of device, unit: byte. @@ -408,10 +421,10 @@ def get_device_memory(self, idx): total (int): the total device memory in bytes, None means failed to get the data. """ try: - mem_used = rocml.smi_get_device_memory_used(idx) - mem_total = rocml.smi_get_device_memory_total(idx) + mem_used = rocml.amdsmi_get_gpu_memory_usage(self._device_handlers[idx], rocml.AmdSmiMemoryType.VRAM) + mem_total = rocml.amdsmi_get_gpu_memory_total(self._device_handlers[idx], rocml.AmdSmiMemoryType.VRAM) except Exception as err: - logger.error('Get device memory failed: {}'.format(str(err))) + logger.warning('Get device memory failed: {}'.format(str(err))) return None, None return mem_used, mem_total @@ -425,8 +438,19 @@ def get_device_ecc_error(self, idx): corrected_ecc (int) : the count of single bit ecc error. uncorrected_ecc (int): the count of double bit ecc error. """ - # Currently no API provided in rocml. 
diff --git a/superbench/runner/playbooks/deploy.yaml b/superbench/runner/playbooks/deploy.yaml
index 516d252b6..4830b97ad 100644
--- a/superbench/runner/playbooks/deploy.yaml
+++ b/superbench/runner/playbooks/deploy.yaml
@@ -100,7 +100,7 @@
       docker run -itd --name={{ container }} \
         --privileged --net=host --ipc=host \
         {{ '--gpus=all' if nvidia_gpu_exist else '' }} \
-        {{ '--security-opt seccomp=unconfined --group-add video' if amd_gpu_exist else '' }} \
+        {{ '--security-opt seccomp=unconfined --group-add video --device=/dev/kfd --device=/dev/dri --cap-add=SYS_PTRACE --shm-size=16G' if amd_gpu_exist else '' }} \
         -w /root -v {{ workspace }}:/root -v /mnt:/mnt \
         -v /var/run/docker.sock:/var/run/docker.sock \
         --entrypoint /bin/bash {{ docker_image }} && \
diff --git a/tests/benchmarks/micro_benchmarks/test_hipblaslt_function.py b/tests/benchmarks/micro_benchmarks/test_hipblaslt_function.py
index f91019f69..98c693a67 100644
--- a/tests/benchmarks/micro_benchmarks/test_hipblaslt_function.py
+++ b/tests/benchmarks/micro_benchmarks/test_hipblaslt_function.py
@@ -102,7 +102,7 @@ def test_hipblaslt_gemm_result_parsing(self):
         self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)

         self.assertEqual(2, len(benchmark.result))
-        self.assertEqual(58624.5, benchmark.result['fp16_1_896_896_896_flops'][0])
+        self.assertEqual(58.6245, benchmark.result['fp16_1_896_896_896_flops'][0])

         # Negative case - invalid raw output
         self.assertFalse(benchmark._process_raw_result(1, 'HipBLAS API failed'))
diff --git a/third_party/Makefile b/third_party/Makefile
index d36152f79..b69259da2 100755
--- a/third_party/Makefile
+++ b/third_party/Makefile
@@ -12,13 +12,13 @@ CUDA_VER ?= $(shell nvcc --version | grep 'release' | awk '{print $$6}' | cut -c
 ROCBLAS_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
 HIPBLASLT_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)

-.PHONY: all cuda_with_msccl cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest cuda_msccl rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt megatron_lm megatron_deepspeed
+.PHONY: all cuda_with_msccl cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest cuda_msccl rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt megatron_lm megatron_deepspeed apex_rocm

 # Build all targets.
 all: cuda rocm
 cuda_with_msccl: cuda cuda_msccl
 cuda: common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest gpcnet cuda_gpuburn megatron_lm megatron_deepspeed
-rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt megatron_deepspeed
+rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt megatron_deepspeed apex_rocm
 cpu: common cpu_perftest
 common: cpu_hpl cpu_stream fio
 directx_amd: directx_amf_encoding_latency
@@ -86,11 +86,11 @@ ifneq (,$(wildcard fio/Makefile))
 	cd ./fio && ./configure --prefix=$(SB_MICRO_PATH) --disable-native && make -j && make install
 endif

-# Build rccl-tests from commit 2a18737 of default branch.
+# Build rccl-tests from commit 46375b1 of default branch.
 rocm_rccl_tests: sb_micro_path
 ifneq (, $(wildcard rccl-tests/Makefile))
-	cd ./rccl-tests && make MPI=1 MPI_HOME=$(MPI_HOME) HIP_HOME=$(HIP_HOME) RCCL_HOME=$(RCCL_HOME) -j
-	cp -v ./rccl-tests/build/* $(SB_MICRO_PATH)/bin/
+	cd ./rccl-tests && make MPI=1 MPI_HOME=$(MPI_HOME) -j
+	cp -v -r ./rccl-tests/build/* $(SB_MICRO_PATH)/bin/
 endif

 # Build rocblas-bench.
@@ -192,25 +192,45 @@ megatron_deepspeed:
 	python -m pip install -r requirements.txt && \
 	python -m pip install DeepSpeed

+# Install apex for ROCm, required by Megatron.
+apex_rocm:
+	$(eval TORCH_VERSION ?= $(shell python -c "import torch; print(torch.__version__)"))
+	$(eval TORCH_MAJOR_VERSION ?= $(word 1,$(subst ., ,$(TORCH_VERSION))))
+	$(eval TORCH_MINOR_VERSION ?= $(word 2,$(subst ., ,$(TORCH_VERSION))))
+	if [ ! -d "apex" ]; then \
+		git clone https://github.com/ROCmSoftwarePlatform/apex.git ; \
+	fi
+	cd apex && \
+	if [ "$$(expr $(TORCH_MAJOR_VERSION) \> 2)" -eq 1 ] && [ "$$(expr $(TORCH_MINOR_VERSION) \> 1)" -eq 1 ]; then \
+		git checkout master ; \
+	elif [ "$$(expr $(TORCH_MAJOR_VERSION) == 2)" -eq 1 ] && [ "$$(expr $(TORCH_MINOR_VERSION) == 1)" -eq 1 ]; then \
+		git checkout release/1.1.0 ; \
+	elif [ "$$(expr $(TORCH_MAJOR_VERSION) == 2)" -eq 1 ] && [ "$$(expr $(TORCH_MINOR_VERSION) == 0)" -eq 1 ]; then \
+		git checkout release/1.0.0 ; \
+	elif [ "$$(expr $(TORCH_MAJOR_VERSION) == 1)" -eq 1 ]; then \
+		git checkout release/1.0.0 ; \
+	fi
+	pip install -v --disable-pip-version-check --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex
+
 # Build MSCCL for CUDA
 cuda_msccl: sb_micro_path
 ifneq (,$(wildcard msccl/executor/msccl-executor-nccl/Makefile))
 	cd ./msccl/executor/msccl-executor-nccl && \
-	make -j4 src.build && \
+	make -j $(shell nproc --ignore=2) src.build && \
 	cd ../../..
 	mkdir -p $(SB_MICRO_PATH)/lib/msccl-executor-nccl && \
 	cp -r -v ./msccl/executor/msccl-executor-nccl/build/* $(SB_MICRO_PATH)/lib/msccl-executor-nccl/
 endif
 ifneq (,$(wildcard msccl/scheduler/msccl-scheduler/Makefile))
 	cd ./msccl/scheduler/msccl-scheduler && \
-	CXX=nvcc BIN_HOME=$(SB_MICRO_PATH)/lib/msccl-executor-nccl SRC_HOME=../../../msccl/executor/msccl-executor-nccl make -j4 && \
+	CXX=nvcc BIN_HOME=$(SB_MICRO_PATH)/lib/msccl-executor-nccl SRC_HOME=../../../msccl/executor/msccl-executor-nccl make -j $(shell nproc --ignore=2) && \
 	cd ../../..
 	mkdir -p $(SB_MICRO_PATH)/lib/msccl-scheduler && \
 	cp -r -v ./msccl/scheduler/msccl-scheduler/build/* $(SB_MICRO_PATH)/lib/msccl-scheduler/
 endif
 ifneq (,$(wildcard msccl/tests/msccl-tests-nccl/Makefile))
 	cd ./msccl/tests/msccl-tests-nccl && \
-	make MPI=1 MPI_HOME=$(MPI_HOME) NCCL_HOME=$(SB_MICRO_PATH)/lib/msccl-executor-nccl -j4 && cd ../../..
+	make MPI=1 MPI_HOME=$(MPI_HOME) NCCL_HOME=$(SB_MICRO_PATH)/lib/msccl-executor-nccl -j $(shell nproc --ignore=2) && cd ../../..
 	mkdir -p $(SB_MICRO_PATH)/bin/msccl-tests-nccl && \
 	cp -r -v ./msccl/tests/msccl-tests-nccl/build/* $(SB_MICRO_PATH)/bin/msccl-tests-nccl/
 endif
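The apex branch selection above, restated in Python for readability (an illustration of the Makefile's `expr` chain, not code in this patch; note that e.g. torch 2.2 matches none of the conditions, so no checkout is performed and the clone stays on its default branch):

```python
def apex_branch(torch_version: str) -> str:
    # Mirrors the expr chain in the apex_rocm rule above.
    major, minor = (int(v) for v in torch_version.split('.')[:2])
    if major > 2 and minor > 1:
        return 'master'
    if (major, minor) == (2, 1):
        return 'release/1.1.0'
    if (major, minor) == (2, 0) or major == 1:
        return 'release/1.0.0'
    return '(default branch)'   # fall-through: no git checkout happens

assert apex_branch('2.1.0') == 'release/1.1.0'
assert apex_branch('1.13.1') == 'release/1.0.0'
```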
+ make MPI=1 MPI_HOME=$(MPI_HOME) NCCL_HOME=$(SB_MICRO_PATH)/lib/msccl-executor-nccl -j $(shell nproc --ignore=2) && cd ../../.. mkdir -p $(SB_MICRO_PATH)/bin/msccl-tests-nccl && \ cp -r -v ./msccl/tests/msccl-tests-nccl/build/* $(SB_MICRO_PATH)/bin/msccl-tests-nccl/ endif diff --git a/third_party/rccl-tests b/third_party/rccl-tests index 2a18737dc..46375b1c5 160000 --- a/third_party/rccl-tests +++ b/third_party/rccl-tests @@ -1 +1 @@ -Subproject commit 2a18737dc681e03ce82c046caa71b28db65017b5 +Subproject commit 46375b1c527b2e3afe80fdd6dd136151bd939675