Skip to content

Commit

Permalink
Fixed inverse so it no longer stomps on its input
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick committed Aug 2, 2023
1 parent 3eb0635 commit 3bc8548
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/matx/generators/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ template <typename T, int RANK> class randomTensorView_t {
}
else if constexpr (is_single_thread_host_executor_v<Executor>) {
if (!init_) {
curandStatus_t ret;
[[maybe_unused]] curandStatus_t ret;

ret = curandCreateGeneratorHost(&gen_, CURAND_RNG_PSEUDO_MT19937);
MATX_ASSERT_STR_EXP(ret, CURAND_STATUS_SUCCESS, matxCudaError, "Failed to create random number generator");
Expand Down
19 changes: 13 additions & 6 deletions include/matx/transforms/inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ class matxInversePlan_t {
* Inverse of A (if it exists)
*
*/
matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a)
matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
{
static_assert(RANK >= 2);

MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

stream_ = stream;

// Ok to remove since we're just passing a list of RO pointers
//using a_nc = typename std::remove_const<decltype(a)>(a);
Expand All @@ -123,26 +125,29 @@ class matxInversePlan_t {
// here as our batch dims
std::vector<const T1 *> in_pointers;
std::vector<T1 *> out_pointers;
make_tensor(tmp_a_, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
(tmp_a_ = a).run(stream);

if constexpr (RANK == 2) {
in_pointers.push_back(&a(0, 0));
in_pointers.push_back(&tmp_a_(0, 0));
out_pointers.push_back(&a_inv(0, 0));
}
else {
using shape_type = typename TensorTypeA::desc_type::shape_type;
int batch_offset = 2;
std::array<shape_type, TensorTypeA::Rank()> idx{0};
auto a_shape = a.Shape();
auto a_shape = tmp_a_.Shape();
// Get total number of batches
size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<shape_type>());
for (size_t iter = 0; iter < total_iter; iter++) {
auto ip = std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx);
auto ip = std::apply([&](auto... param) { return tmp_a_.GetPointer(param...); }, idx);
auto op = std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, idx);

in_pointers.push_back(ip);
out_pointers.push_back(op);

// Update all but the last 2 indices
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a, idx, batch_offset);
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(tmp_a_, idx, batch_offset);
}
}

Expand Down Expand Up @@ -307,6 +312,8 @@ class matxInversePlan_t {
int *d_info;
T1 **d_A_array;
T1 **d_A_inv_array;
cudaStream_t stream_;
matx::tensor_t<typename TensorTypeA::scalar_type, TensorTypeA::Rank()> tmp_a_;
};

/**
Expand Down Expand Up @@ -367,7 +374,7 @@ void inv_impl(TensorTypeAInv &a_inv, const TensorTypeA &a,
// Get cache or new inverse plan if it doesn't exist
auto ret = detail::inv_cache.Lookup(params);
if (ret == std::nullopt) {
auto tmp = new detail::matxInversePlan_t{a_inv, a};
auto tmp = new detail::matxInversePlan_t{a_inv, a, stream};
detail::inv_cache.Insert(params, static_cast<void *>(tmp));
tmp->Exec(stream);
}
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/resample_poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ inline void resample_poly_impl(OutType &out, const InType &in, const FilterType
}

const index_t up_size = in.Size(RANK-1) * up;
const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);
[[maybe_unused]] const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);

MATX_ASSERT_STR(out.Size(RANK-1) == outlen, matxInvalidDim, "resample_poly: output size mismatch");

Expand Down

0 comments on commit 3bc8548

Please sign in to comment.