Skip to content

Commit

Permalink
Add numerical checking helper to Level 3 functions (#1238)
Browse files Browse the repository at this point in the history
* Add numerical checking helper to Level 3 rocBLAS

* Added check to see if the input is const

* Enclosed the kernel function of TRSM with brackets to invoke the destructor and release the handle memory

* Addressed the comments
  • Loading branch information
NaveenElumalaiAMD authored May 11, 2022
1 parent 1585f4d commit be030fe
Show file tree
Hide file tree
Showing 49 changed files with 3,545 additions and 794 deletions.
93 changes: 74 additions & 19 deletions library/src/blas3/rocblas_dgmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ namespace

RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle);

auto layer_mode = handle->layer_mode;
auto layer_mode = handle->layer_mode;
auto check_numerics = handle->check_numerics;

if(layer_mode
& (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench
Expand Down Expand Up @@ -113,25 +114,79 @@ namespace
static constexpr rocblas_int batch_count = 1;
static constexpr rocblas_stride stride_A = 0, stride_x = 0, stride_C = 0;

return rocblas_dgmm_template(handle,
side,
m,
n,
A,
offset_A,
lda,
stride_A,
x,
offset_x,
incx,
stride_x,
C,
offset_C,
ldc,
stride_C,
batch_count);
}
if(check_numerics)
{
bool is_input = true;
rocblas_status dgmm_check_numerics_status
= rocblas_dgmm_check_numerics(rocblas_dgmm_name<T>,
handle,
side,
m,
n,
A,
lda,
stride_A,
x,
incx,
stride_x,
C,
ldc,
stride_C,
batch_count,
check_numerics,
is_input);
if(dgmm_check_numerics_status != rocblas_status_success)
return dgmm_check_numerics_status;
}

rocblas_status status = rocblas_status_success;
status = rocblas_dgmm_template(handle,
side,
m,
n,
A,
offset_A,
lda,
stride_A,
x,
offset_x,
incx,
stride_x,
C,
offset_C,
ldc,
stride_C,
batch_count);

if(status != rocblas_status_success)
return status;

if(check_numerics)
{
bool is_input = false;
rocblas_status dgmm_check_numerics_status
= rocblas_dgmm_check_numerics(rocblas_dgmm_name<T>,
handle,
side,
m,
n,
A,
lda,
stride_A,
x,
incx,
stride_x,
C,
ldc,
stride_C,
batch_count,
check_numerics,
is_input);
if(dgmm_check_numerics_status != rocblas_status_success)
return dgmm_check_numerics_status;
}
return status;
}
} // namespace

/*
Expand Down
21 changes: 21 additions & 0 deletions library/src/blas3/rocblas_dgmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* ************************************************************************ */

#pragma once
#include "check_numerics_matrix.hpp"
#include "check_numerics_vector.hpp"
#include "handle.hpp"

/**
Expand All @@ -47,3 +49,22 @@ rocblas_status rocblas_dgmm_template(rocblas_handle handle,
rocblas_int ldc,
rocblas_stride stride_c,
rocblas_int batch_count);

template <typename TConstPtr, typename TPtr>
rocblas_status rocblas_dgmm_check_numerics(const char* function_name,
rocblas_handle handle,
rocblas_side side,
rocblas_int m,
rocblas_int n,
TConstPtr A,
rocblas_int lda,
rocblas_stride stride_A,
TConstPtr x,
rocblas_int incx,
rocblas_stride stride_x,
TPtr C,
rocblas_int ldc,
rocblas_stride stride_c,
rocblas_int batch_count,
const int check_numerics,
bool is_input);
92 changes: 74 additions & 18 deletions library/src/blas3/rocblas_dgmm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ namespace

RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle);

auto layer_mode = handle->layer_mode;
auto layer_mode = handle->layer_mode;
auto check_numerics = handle->check_numerics;

if(layer_mode
& (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench
Expand Down Expand Up @@ -129,23 +130,78 @@ namespace
static constexpr rocblas_stride offset_a = 0, offset_x = 0, offset_c = 0;
static constexpr rocblas_stride stride_a = 0, stride_x = 0, stride_c = 0;

return rocblas_dgmm_template(handle,
side,
m,
n,
A,
offset_a,
lda,
stride_a,
x,
offset_x,
incx,
stride_x,
C,
offset_c,
ldc,
stride_c,
batch_count);
if(check_numerics)
{
bool is_input = true;
rocblas_status dgmm_check_numerics_status
= rocblas_dgmm_check_numerics(rocblas_dgmm_batched_name<T>,
handle,
side,
m,
n,
A,
lda,
stride_a,
x,
incx,
stride_x,
C,
ldc,
stride_c,
batch_count,
check_numerics,
is_input);
if(dgmm_check_numerics_status != rocblas_status_success)
return dgmm_check_numerics_status;
}

rocblas_status status = rocblas_status_success;
status = rocblas_dgmm_template(handle,
side,
m,
n,
A,
offset_a,
lda,
stride_a,
x,
offset_x,
incx,
stride_x,
C,
offset_c,
ldc,
stride_c,
batch_count);

if(status != rocblas_status_success)
return status;

if(check_numerics)
{
bool is_input = false;
rocblas_status dgmm_check_numerics_status
= rocblas_dgmm_check_numerics(rocblas_dgmm_batched_name<T>,
handle,
side,
m,
n,
A,
lda,
stride_a,
x,
incx,
stride_x,
C,
ldc,
stride_c,
batch_count,
check_numerics,
is_input);
if(dgmm_check_numerics_status != rocblas_status_success)
return dgmm_check_numerics_status;
}
return status;
}

} // namespace
Expand Down
114 changes: 114 additions & 0 deletions library/src/blas3/rocblas_dgmm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,81 @@ rocblas_status rocblas_dgmm_template(rocblas_handle handle,
}
return rocblas_status_success;
}

template <typename TConstPtr, typename TPtr>
rocblas_status rocblas_dgmm_check_numerics(const char* function_name,
rocblas_handle handle,
rocblas_side side,
rocblas_int m,
rocblas_int n,
TConstPtr A,
rocblas_int lda,
rocblas_stride stride_a,
TConstPtr x,
rocblas_int incx,
rocblas_stride stride_x,
TPtr C,
rocblas_int ldc,
rocblas_stride stride_c,
rocblas_int batch_count,
const int check_numerics,
bool is_input)
{

rocblas_status check_numerics_status = rocblas_status_success;
if(is_input)
{
rocblas_int dim_x = (side == rocblas_side_left) ? m : n;
check_numerics_status
= rocblas_internal_check_numerics_matrix_template(function_name,
handle,
rocblas_operation_none,
rocblas_fill_full,
rocblas_client_general_matrix,
m,
n,
A,
0,
lda,
stride_a,
batch_count,
check_numerics,
is_input);
if(check_numerics_status != rocblas_status_success)
return check_numerics_status;

check_numerics_status = rocblas_internal_check_numerics_vector_template(function_name,
handle,
dim_x,
x,
0,
incx,
stride_x,
batch_count,
check_numerics,
is_input);
if(check_numerics_status != rocblas_status_success)
return check_numerics_status;
}
check_numerics_status
= rocblas_internal_check_numerics_matrix_template(function_name,
handle,
rocblas_operation_none,
rocblas_fill_full,
rocblas_client_general_matrix,
m,
n,
C,
0,
ldc,
stride_c,
batch_count,
check_numerics,
is_input);

return check_numerics_status;
}

// Instantiations below will need to be manually updated to match any change in
// template parameters in the files dgmm*.cpp

Expand Down Expand Up @@ -198,4 +273,43 @@ INSTANTIATE_DGMM_TEMPLATE(double const* const*, double* const*)
INSTANTIATE_DGMM_TEMPLATE( rocblas_float_complex const* const*, rocblas_float_complex* const*)
INSTANTIATE_DGMM_TEMPLATE(rocblas_double_complex const* const*, rocblas_double_complex* const*)
#undef INSTANTIATE_DGMM_TEMPLATE


#ifdef INSTANTIATE_DGMM_NUMERICS
#error INSTANTIATE_DGMM_NUMERICS already defined
#endif

#define INSTANTIATE_DGMM_NUMERICS(TConstPtr_, TPtr_) \
template rocblas_status rocblas_dgmm_check_numerics<TConstPtr_, TPtr_> \
(const char* function_name, \
rocblas_handle handle, \
rocblas_side side, \
rocblas_int m, \
rocblas_int n, \
TConstPtr_ A, \
rocblas_int lda, \
rocblas_stride stride_a, \
TConstPtr_ x, \
rocblas_int inc, \
rocblas_stride stride_x, \
TPtr_ C, \
rocblas_int ldc, \
rocblas_stride stride_c, \
rocblas_int batch_count, \
const int check_numerics, \
bool is_input);

// instantiate for rocblas_Xdgmm and rocblas_Xdgmm_strided_batched
INSTANTIATE_DGMM_NUMERICS(float const*, float*)
INSTANTIATE_DGMM_NUMERICS(double const*, double*)
INSTANTIATE_DGMM_NUMERICS(rocblas_float_complex const*, rocblas_float_complex*)
INSTANTIATE_DGMM_NUMERICS(rocblas_double_complex const*, rocblas_double_complex*)

// instantiate for rocblas_Xdgmm_batched
INSTANTIATE_DGMM_NUMERICS(float const* const*, float* const*)
INSTANTIATE_DGMM_NUMERICS(double const* const*, double* const*)
INSTANTIATE_DGMM_NUMERICS(rocblas_float_complex const* const*, rocblas_float_complex* const*)
INSTANTIATE_DGMM_NUMERICS(rocblas_double_complex const* const*, rocblas_double_complex* const*)

#undef INSTANTIATE_DGMM_NUMERICS
// clang-format on
Loading

0 comments on commit be030fe

Please sign in to comment.