Skip to content

Commit

Permalink
Fix bug with eye, and also zero workspace before LU factorization (#807)
Browse files (browse the repository at this point in the history)
  • Loading branch information
cliffburdick authored Nov 21, 2024
1 parent ab409ce commit b3869e0
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions include/matx/transforms/inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ class matxInversePlan_t {
else if (backend == MatInverseLUBackend::cuSolverGetRfRs) {
MATX_ASSERT_STR(params.batch_size == 1, matxInvalidParameter, "cuSolverGetRfRs backend only used for single batches");

// cuSolver has a bug that requires this workspace to be zeroed each time
cudaMemsetAsync(d_workspace, 0, this->dspace, stream);

[[maybe_unused]] cusolverStatus_t solver_ret;
solver_ret = cusolverDnXgetrf(
cusolver_handle,
Expand All @@ -492,10 +495,10 @@ class matxInversePlan_t {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
}
}
}

// We're solving Ax = b, so setting "b" to the identity matrix will give us A^-1
(a_inv = eye()).run(stream);
(a_inv = eye<typename TensorTypeA::value_type, 2>({params.n, params.n})).run(stream);

solver_ret = cusolverDnXgetrs(
cusolver_handle,
Expand Down

0 comments on commit b3869e0

Please sign in to comment.