
[QST] MatX is around x15 slower than CuPy for the same task #688

Open · HugoPhibbs opened this issue Aug 3, 2024 · 71 comments · Fixed by #703

@HugoPhibbs (Contributor) commented Aug 3, 2024

Gidday.

I'm a bit of a novice with MatX and C++, and was looking to get some help with optimising my MatX code.

So basically I'm trying to refactor code that was originally written in CuPy into lightning-fast MatX code, except I've found that my MatX implementation, despite looking (to me) like an identical equivalent of my CuPy code, is a lot slower. I was wondering if anybody could give me some tips as to where my code might be slowing down.

FYI, a general assumption I'm making is that MatX's operators are super lightweight - so the reshapes and repmats are all super quick.

My MatX code looks like:

matx::tensor_t<matx::matxFp16, 2>  GsDBSCAN::findDistancesMatX(matx::tensor_t<matx::matxFp16, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha, int batchSize) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];

    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];
    int D = B_t.Shape()[0] / 2;

    batchSize = (batchSize != -1) ? batchSize: GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);

    auto AFlat_t = matx::flatten(A_t);

    auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m});

    for (int i = 0; i < n; i += batchSize) {
        int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS

        auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});

        auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});

        auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);

        auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));

        auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2*k*m, d});

        auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});

        auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})); // repmat is a workaround for subtracting naively incompatible tensor shapes

        auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);

        (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
    }

    return distances_t;
}

And the same CuPy code looks like:

def find_distances(X, A, B, alpha=1.2, batch_size = -1):
    k = A.shape[1] // 2
    m = B.shape[1]

    n = X.shape[0]
    d = X.shape[1]
    D = B.shape[0] // 2

    batch_size = batch_size if batch_size != -1 else get_batch_size(n, d, k, m, alpha=alpha)

    distances = cp.empty(shape=(n, 2 * k * m),
                         dtype=cp.float16)  # float32 causes a memory overload. float16 is fine (for eps 2DP)

    for i in range(0, n, batch_size):
        max_batch_idx = min(i + batch_size, X.shape[0])

        Z_batch = X[B[A[i:max_batch_idx]]]
      
        # (Edit): Changed the reshape call to be a little clearer. Z_batch_adj is equivalent to XBatchReshaped_t_op above.
        Z_batch_adj = Z_batch.reshape(batch_size, 2 * k * m,  d)

        Y_batch = Z_batch_adj - X[i:max_batch_idx, cp.newaxis, :]

        distances[i:max_batch_idx] = cp.linalg.norm(Y_batch, axis=2)

    return distances

The parameters used for both are:

n = 70_000
k = 5
m = 50
d = 784
D = 1024
batchSize ~= 250 (FYI it should always be a divisor of n - I found the CuPy implementation was a lot slower otherwise on the final iteration; see the sketch below).
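
GsDBSCAN::findDistanceBatchSize isn't shown in the thread; as a rough illustration, here is a C++ sketch of a divisor-based batch size that mirrors the Python get_batch_size posted further down (the name and exact formula are illustrative, not the actual implementation):

// Sketch: estimate a batch size from a memory budget, then walk down to a
// divisor of n so every batch is full-sized and the last iteration isn't
// a ragged special case.
int findDistanceBatchSizeSketch(float alpha, int n, int d, int k, int m) {
    // Mirrors the Python: batch_size = (n * d * 2*k*m) // (1024**3 * alpha)
    long long totalElements = static_cast<long long>(n) * d * 2 * k * m;
    int batchSize = static_cast<int>(totalElements / static_cast<long long>(1024.0 * 1024 * 1024 * alpha));

    if (batchSize == 0) {
        return n; // small enough to do in one batch
    }

    for (int divisor = batchSize - 1; divisor > 0; divisor--) {
        if (n % divisor == 0) {
            return divisor; // largest divisor of n below the estimate
        }
    }
    return batchSize;
}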

Regarding results, the MatX code takes around 14.5 seconds to complete, but CuPy takes 0.9 seconds (including CUDA synchronisations).

As a baseline, a multithreaded (64 threads) CPU implementation of the above code (using loops with no tensors involved) takes less than 0.7 seconds. A single threaded CPU implementation takes around 7 seconds - (this is using the same machine of course).

Sorry if the variable names are a little cryptic.

I've tested for around n = 1000 and found that the two implementations produce the same results (albeit with a small amount of floating point errors).

Thanks in advance.

@cliffburdick (Collaborator):

Hi @HugoPhibbs, this is very interesting and unexpected. We'll take a look at the profile and get back to you.

@luitjens (Collaborator) commented Aug 3, 2024 via email

@HugoPhibbs (Author):

Hi @luitjens

I'm using batches because otherwise my GPU quickly runs out of memory (I tried no batching with CuPy and that was the result). Batching is used to control the memory usage of the intermediate tensors.

I intend to tune the batch size in the future to make optimal use of GPU memory, but right now I'm focused on getting an MVP.

@luitjens (Collaborator) commented Aug 3, 2024 via email

@HugoPhibbs (Author):

Hi @luitjens, thx for getting back to me.

I'm timing the complete function runtime - as in how long it takes to run the function start to finish. The timing looks a bit like this:

TEST_F(TestFindingDistances, TestLargeInputMatX) {
    int k = 5;
    int n = 70000;
    int m = 50;
    int D = 1024;
    int d = 784;

    auto A = tu::createMockAMatrixMatX(n, k, D);
    auto B = tu::createMockBMatrixMatX(n, m, D);
    auto X = tu::createMockMnistDatasetMatX(n, d);

    cudaDeviceSynchronize(); // Possibly not necessary?

    tu::Time start = tu::timeNow();

    auto distances = GsDBSCAN::findDistancesMatX(X, A, B, 1.2, 250);
    cudaDeviceSynchronize();

    tu::printDurationSinceStart(start); 

    printf("%lld %lld", distances.Shape()[0], distances.Shape()[1]);

    ASSERT_TRUE(distances.Shape()[0] == n);
    ASSERT_TRUE(distances.Shape()[1] == 2*k*m);
}

As for memory options, I changed the memory space of all the tensors to matx::MATX_DEVICE_MEMORY and I'm still getting the same 14.5 second runtime. E.g. what I did was the below for all the make_tensor calls:

inline auto createMockAMatrixMatX(int n = 70000, int k = 2, int D = 1024) {
    auto A = matx::make_tensor<float>({n, 2*k}, matx::MATX_DEVICE_MEMORY);
    auto A_i = matx::make_tensor<int32_t>({n, 2*k}, matx::MATX_DEVICE_MEMORY);

    int a = 2 * (D - 1);

    (A = matx::random<float>({n, 2*k}, matx::UNIFORM, 0, a)).run();
    (A_i = matx::as_type<int32_t>(A)).run();

    return A_i;
}
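
A note on the timing approach here: because MatX enqueues work asynchronously, host-side timers only bracket kernel launches unless a synchronize happens before the clock stops, so the cudaDeviceSynchronize() above is what makes the measurement meaningful. An alternative sketch using the standard CUDA event API (not code from this thread) that times only the GPU work:

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start, 0); // enqueue a start marker on the default stream

auto distances = GsDBSCAN::findDistancesMatX(X, A, B, 1.2, 250);

cudaEventRecord(stop, 0);   // enqueue a stop marker after all the batched work
cudaEventSynchronize(stop); // block until the stop marker is reached

float ms = 0.0f;
cudaEventElapsedTime(&ms, start, stop); // GPU time between the two markers
printf("GPU time: %f ms\n", ms);

cudaEventDestroy(start);
cudaEventDestroy(stop);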

@HugoPhibbs (Author):

Hi again,

I've done some more testing, and I've found that the CUDA synchronise step takes the lion's share of the runtime. I added some hacky profiling to the function, which looks like this:

matx::tensor_t<matx::matxFp16, 2>  GsDBSCAN::findDistancesMatX(matx::tensor_t<matx::matxFp16, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha, int batchSize) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];

    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];
    int D = B_t.Shape()[0] / 2;

    batchSize = (batchSize != -1) ? batchSize : GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);

    auto AFlat_t = matx::flatten(A_t);

    auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m}, matx::MATX_DEVICE_MEMORY);

    int j = 0;
    std::vector<double> times;

    auto start_all = std::chrono::high_resolution_clock::now();

    for (int i = 0; i < n; i += batchSize) {
        auto start = std::chrono::high_resolution_clock::now();

        int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS

        auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});

        auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});

        auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);

        auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));

        auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2*k*m, d});

        auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});

        auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})); // repmat is a workaround for subtracting naively incompatible tensor shapes

        auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);

        (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();

        // Record end time
        auto end = std::chrono::high_resolution_clock::now();

        // Calculate the duration
        std::chrono::duration<double> duration = end - start;

        // Cast to double and store in array
        times.push_back(duration.count());
    }

    auto start_sync = std::chrono::high_resolution_clock::now();

    cudaDeviceSynchronize();

    // Record end time
    auto end_sync = std::chrono::high_resolution_clock::now();

    // Calculate the duration
    std::chrono::duration<double> duration_sync = end_sync - start_sync;

    // Output the duration
    std::cout << "Time taken: " << duration_sync.count() << " seconds" << std::endl;

    for (const auto& element : times) {
        std::cout << element << std::endl;
    }

    // Record end time
    auto end_all = std::chrono::high_resolution_clock::now();

    // Calculate the duration
    std::chrono::duration<double> duration = end_all - start_all;

    // Output the duration
    std::cout << "Time taken: " << duration.count() << " seconds" << std::endl;

    return distances_t;
}

Which produces the output:

Time taken: 14.4069 seconds // For the synchronise
0.00528887
1.775e-05
... a bunch more times around 1.5e-5 (with approx. 300 total loop runs, this gives around 0.005 ≈ 300 × 1.5e-5 seconds for the entire loop, which is very quick)
1.642e-05
Time taken: 14.4189 seconds // For the overall function call

Has this got something to do with the fact that MatX appears to have an async execution style? I.e. could queueing up a bunch of async operations on the GPU produce a large bottleneck? Just an idea.
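
(That guess is on the right track: MatX's run() enqueues work on a CUDA stream and returns immediately, so the per-iteration chrono timings above only measure launch overhead, while the final cudaDeviceSynchronize() absorbs all of the actual GPU execution time. A sketch of how the loop could be timed so each measurement brackets real kernel execution - useful for profiling only, since per-batch synchronization serializes the pipeline:)

for (int i = 0; i < n; i += batchSize) {
    auto start = std::chrono::high_resolution_clock::now();

    // ... build the operators and issue the .run() exactly as in the loop above ...

    cudaStreamSynchronize(0); // wait for this batch's kernels before stopping the clock
    auto end = std::chrono::high_resolution_clock::now();
    times.push_back(std::chrono::duration<double>(end - start).count());
}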

@luitjens (Collaborator) commented Aug 5, 2024

Hi, can you please provide a fully buildable/runnable example in both MatX and Python that we can use to compare?

Generally speaking you don't want to include allocation time in your timings as you want to allocate once upfront and reuse.

@luitjens (Collaborator) commented Aug 5, 2024

Alternatively, if you cannot easily create a standalone reproducer, can you share an nsys profile of both the Python and the MatX versions with us?

@HugoPhibbs (Author) commented Aug 6, 2024

Ok thx, pls see the gist: https://gist.github.com/HugoPhibbs/a2ce2c75b70c6737f1094f32b15af3ea

It contains source files to run it, along with an nsys profile

@luitjens (Collaborator) commented Aug 6, 2024

I recreated your repro as an example and had to make a few modifications to get it to build. Once I did that I ran on H100 and I see this output:

Total Time taken: 0.0242754 seconds
Total Time taken (again): 0.0248659 seconds
70000 500

Unfortunately I was not able to view your profile as it says it is corrupt. Could you get a fresh profile, put it in a zip and upload it to your example?

@luitjens (Collaborator) commented Aug 6, 2024

Hugo, I created a repro with some build fixes here: https://github.com/NVIDIA/MatX/tree/688-repro

From your build directory:
%> make repro
%> ./examples/repro

Can you verify that the issue still reproduces?

@luitjens (Collaborator) commented Aug 6, 2024

On an L40 I see similar performance:

Total Time taken: 0.0205341 seconds
Total Time taken (again): 0.0211705 seconds

@HugoPhibbs (Author) commented Aug 6, 2024

Ok thanks. Honestly I'm a little bit skeptical that it would take just a fraction of a second. But yes, the error still reproduces on my machine:

make repro
[ 50%] Building CUDA object examples/CMakeFiles/repro.dir/repro.cu.o
[100%] Linking CUDA executable repro
[100%] Built target repro

./examples/repro
Sync Time taken: 14.3653 seconds
0.00239233
2.32e-05
....
1.776e-05
Total Time taken: 14.3747 seconds
Total Time taken (again): 14.375 seconds
70000 500

Please see this zip for the profile: test_profile.zip - the earlier corruption may have been an encoding issue.

I guess now would be a good time to show you my environment:

nvidia-smi
Wed Aug  7 09:38:27 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 53%   45C    P5             63W /  390W |    1022MiB /  24576MiB |     47%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        Off |   00000000:4A:00.0  On |                  N/A |
|  0%   48C    P8             49W /  390W |     115MiB /  24576MiB |     25%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

@luitjens (Collaborator) commented Aug 6, 2024

Thank you for the profile. When I inspect the profiles, the MatX-generated kernels seem in line with the hardware (2 ms on H100 and 6.5 ms on 3090). However, the reduction kernel seems way off. We use CUB for this kernel, so perhaps there is something going wrong in CUB. We will investigate this. As a workaround, can you try to materialize the inputs to the reduction kernel into a memory-backed tensor, then compute the vector norm on the memory-backed tensor: https://github.com/NVIDIA/MatX/blob/688-repro/examples/repro.cu#L96
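
(A sketch of that workaround using the operator names from the repro above - the allocation mode here is an assumption; the exact fix the maintainers used appears later in the thread:)

auto YBatch_tmp = matx::make_tensor<matx::matxFp16>({batchSize, 2*k*m, d}, matx::MATX_ASYNC_DEVICE_MEMORY, 0);
(YBatch_tmp = YBatch_t_op).run(); // materialize the fused operator into memory
auto YBatch_t_norm_op = matx::vector_norm(YBatch_tmp, {2}, matx::NormOrder::L2); // norm now reads a plain memory-backed tensor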

@luitjens (Collaborator) commented Aug 6, 2024

Can you also get me an ncu profile with this command on your system:

ncu --import-source --set full --metrics all --kernel-id "::regex:.*ReduceKernel.*:1" -o 3090 ./examples/repro

Then zip up 3090.ncu-rep and attach that too.

@luitjens (Collaborator) commented Aug 7, 2024

Updated the ncu instructions above.

@luitjens (Collaborator) commented Aug 7, 2024

Also, can you try updating your toolkit?

You currently have CUDA 11.8.

I'd suggest going to 12.5.

@cliffburdick (Collaborator):

@HugoPhibbs I ran this on both an A100 and 3090. Here are the results:
A100:

Total Time taken: 0.0451949 seconds
Total Time taken (again): 0.0453394 seconds

3090:

Total Time taken: 0.0272116 seconds
Total Time taken (again): 0.0280414 seconds

This is CUDA 12.5. I will try 11.8 and report back.

@cliffburdick (Collaborator):

@HugoPhibbs on my nsys capture I see 32 registers per thread, whereas @luitjens pointed out you had 128. Here is our compilation line:

cd /repro/tmp/MatX/build/examples && /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DMATX_DISABLE_CUB_CACHE -DMATX_ENABLE_FILEIO -DMATX_ENABLE_PYBIND11 -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_CUDA -DTHRUST_HOST_SYSTEM=THRUST_HOST_SYSTEM_CPP -I/repro/tmp/MatX/include -I/repro/tmp/MatX/include/matx/kernels -I/repro/tmp/MatX/build/_deps/cccl-src/thrust/thrust/cmake/../.. -I/repro/tmp/MatX/build/_deps/cccl-src/libcudacxx/lib/cmake/libcudacxx/../../../include -I/repro/tmp/MatX/build/_deps/cccl-src/cub/cub/cmake/../.. -isystem=/repro/tmp/MatX/build/_deps/pybind11-src/include -isystem=/usr/include/python3.10 -isystem=/usr/local/cuda/include -g --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Wall -Wextra -Wcast-align -Wunused -Wshadow -Wno-unknown-pragmas -Wnon-virtual-dtor -Wconversion -Wmisleading-indentation -Wduplicated-cond -Wduplicated-branches -Wlogical-op -Wnull-dereference -Werror all-warnings --threads 0 -ftemplate-backtrace-limit=0 -lineinfo --expt-relaxed-constexpr -DMATX_ROOT=\"/repro/tmp/MatX\" -fvisibility=hidden -MD -MT examples/CMakeFiles/repro.dir/repro.cu.o -MF CMakeFiles/repro.dir/repro.cu.o.d -x cu -c /repro/tmp/MatX/examples/repro.cu -o CMakeFiles/repro.dir/repro.cu.o

Can you send what yours looks like? If you're importing the matx::matx target it should look similar

@cliffburdick (Collaborator):

@HugoPhibbs I was able to reproduce your issue on CUDA 11.8 with everything else the same:

Total Time taken: 14.6827 seconds
Total Time taken (again): 14.6836 seconds

Is it possible for you to update? This may be an issue where nvcc had trouble with register reuse in this case causing poor occupancy.

@HugoPhibbs (Author):

thx @cliffburdick and @luitjens

@luitjens re ncu: I'm currently waiting for admin permissions to run sudo ncu; I'll send the results once I can.

On the CUDA-upgrade front: I upgraded to 12.5 and it runs OK, but now my tests are broken 🙃.

Just to make sure: cudaDeviceSynchronize() ensures that any pending operations on the GPU are done, right? Since upgrading to 12.5, the returned distances_t tensor is now just empty (full of zeros), whereas with 11.8 it was full of values.

E.g. my simple tests look a bit like:

auto distances_t = GsDBSCAN::findDistancesMatX(X_t_16, A_t, B_t);

cudaDeviceSynchronize();

matx::matxFp16 *distances_ptr = distances_t.Data();

matx::matxFp16 expected_squared[] = {
        11, 5, 14, 11, 0, 5,
        9, 0, 11, 0, 14, 11,
        5, 0, 5, 5, 8, 14,
        9, 5, 0, 0, 9, 5,
        9, 6, 5, 5, 0, 6
};

for (int i = 0; i < 5*6; i++) {
    ASSERT_NEAR(std::sqrt(expected_squared[i]), distances_ptr[i], 1e-3); // distances is full of zeros with 12.5, but full of values with 11.8
}

Do you guys know a reason why this may be?

@luitjens (Collaborator) commented Aug 7, 2024 via email

@HugoPhibbs (Author):

@luitjens yep, added the macro and no errors occur.
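
(The macro in question was suggested by email and isn't shown in the thread; a generic CUDA error-checking macro along those lines - a sketch, not necessarily the exact one used - looks like this:)

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Surfaces asynchronous kernel failures that would otherwise leave output silently zeroed
#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error '%s' at %s:%d\n",             \
                    cudaGetErrorString(err_), __FILE__, __LINE__);    \
            exit(EXIT_FAILURE);                                       \
        }                                                             \
    } while (0)

// Usage: CUDA_CHECK(cudaDeviceSynchronize());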

@luitjens (Collaborator) commented Aug 7, 2024 via email

@HugoPhibbs (Author):

Sure, will do.

@luitjens (Collaborator) commented Aug 7, 2024 via email

@HugoPhibbs (Author):

Ok, have checked, but I don't think that's the case, since I'm using managed memory?

I was getting a seg fault when using device memory, so I changed the memory space to managed and it worked (in CUDA 11.8).

@luitjens (Collaborator) commented Aug 8, 2024 via email

@cliffburdick (Collaborator):

@HugoPhibbs we're still looking into it, but we can reproduce your issue.

@HugoPhibbs (Author) commented Aug 11, 2024

@luitjens @cliffburdick I have tried a fresh rebuild and using floats instead of halves. The behaviour is more or less the same.

I'm getting strange behaviour depending on whether I run it from a .cpp file or a .cu file - 14 seconds or so under GTest, but 0.2 secs when I reproduce the same thing as an example in a .cu file.

This is the file I'm running: https://gist.github.com/HugoPhibbs/1bfd7180119040186b57b515dff4f69d

That file was run from the examples directory of the MatX project. I compiled and ran it with make repro2 and ./examples/repro2. This produces a runtime of 0.2 secs, but the whole distances tensor has zeros all through it. I thought it might be a memory-space issue, but all the tensors are in managed memory, so I don't think it's that.

I've commented out the print loop. Can you guys check if it has the same behavior for you? - i.e. if it's all zeros?

@HugoPhibbs (Author) commented Aug 11, 2024

re gtest being slower, I can confirm that this is probably something to do with my project instead. When I copy the repro2 file from the MatX repo to mine, it goes from 0.2 secs to 14 seconds.

If it's helpful my cmake looks like:

cmake_minimum_required(VERSION 3.27)
project(DbscanCEOs LANGUAGES CXX CUDA C)
#enable_language(CUDA)
project(sDbscan)

set(CMAKE_CXX_STANDARD 17)
SET(CMAKE_CUDA_ARCHITECTURES 86)
set(CMAKE_CUDA_COMPILER "/usr/local/cuda-12.6/bin/nvcc") # Somehow CLion needs this here (smh)
#SET(CMAKE_C_COMPILER "/usr/bin/g++")
#add_definitions(-DINDEX_64_BIT)

#SET(CMAKE_BUILD_TYPE Debug)

find_package(Eigen3 3.3 REQUIRED NO_MODULE)
find_package(Boost 1.71 REQUIRED NO_MODULE)

# CCCL
include(cmake/CPM.cmake)

# This will automatically clone CCCL from GitHub and make the exported cmake targets available
CPMAddPackage(
    NAME CCCL
    GITHUB_REPOSITORY nvidia/cccl
    GIT_TAG v2.4.0
)

find_package(OpenMP)
if (OPENMP_FOUND)
    set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()
set(CUDA_TOOLKIT_ROOT_DIR $ENV{CUDA_HOME})
find_package(CUDAToolkit 12.6 REQUIRED)
#find_package(CUDAToolkit 11.8 REQUIRED)

# ArrayFire
find_package(ArrayFire REQUIRED)

# MatX https://github.com/NVIDIA/MatX
find_package(matx CONFIG REQUIRED)

include_directories(
        ${PROJECT_SOURCE_DIR}/lib/eigen-3.4.0
        ${CUDA_TOOLKIT_ROOT_DIR}/include
        ${gtest_SOURCE_DIR}/include
        ${gtest_SOURCE_DIR}
)

link_directories(
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64
)

add_subdirectory(lib/googletest)

set_source_files_properties(
        test/gsDBSCAN/GsDBSCANTest.cpp
        test/gsDBSCAN/UtilsTest.cpp
        test/gsDBSCAN/PreprocessingTest.cpp
        test/gsDBSCAN/DistancesTest.cpp
        test/gsDBSCAN/ClusteringTest.cpp
        include/gsDBSCAN/clustering.h
        test/repro2.cu
        PROPERTIES LANGUAGE CUDA)

add_executable(${PROJECT_NAME}
        src/main.cpp
        src/Utilities.cpp
        src/dbscan/sDbscan.cpp
        src/fast_copy.c
        src/fht/fht.c
        include/gsDBSCAN/preprocessing.h
        include/gsDBSCAN/utils.h
        include/gsDBSCAN/distances.h
        include/gsDBSCAN/clustering.h
)

add_executable(run_gs_dbscan_tests
        test/gsDBSCAN/GsDBSCANTest.cpp
        test/TestUtils.cpp
        test/gsDBSCAN/UtilsTest.cpp
        test/gsDBSCAN/PreprocessingTest.cpp
        test/gsDBSCAN/DistancesTest.cpp
        test/gsDBSCAN/ClusteringTest.cpp
)

add_executable(repro2
        test/repro2.cu
)


target_link_libraries(run_gs_dbscan_tests PRIVATE CCCL::CCCL CUDA::cudart ArrayFire::afcuda matx::matx gtest gtest_main)
target_link_libraries(${PROJECT_NAME} PRIVATE CCCL::CCCL CUDA::cudart ArrayFire::afcuda Eigen3::Eigen matx::matx)

target_link_libraries(repro2 PRIVATE CCCL::CCCL matx::matx)

@cliffburdick (Collaborator):

> re gtest being slower, I can confirm that this is probably something to do with my project instead. When I copy the repro2 file from the MatX repo to mine, it goes from 0.2 secs to 14 seconds. [...]

I would issue a make -j VERBOSE=1 in both cases to see which flags are different. If none, then check your CCCL directory and make sure it's the right version in both. If it looks like they're the same, I would try to swap out the CCCL -I paths from the "bad" one with the ones from the "good" one and see if that makes a difference.

@HugoPhibbs (Author) commented Aug 12, 2024

I retested this on the main branch of MatX. I'm now getting 13.8 seconds or so both from my repo and from running repro2 within the MatX repo.

I made the mistake of earlier running repro2 within the 688-repro branch of the main repo - that may explain the difference. Regardless, I'm still getting quite a long runtime for repro2, whether from my own repository or not.

This is what it looks like:

From MatX repo (up to date with 8th August, so it should have the fix for the reshape bug)

:~/Documents/MatX/build$ ./examples/repro2
CUDA call successful: /home/hphi344/Documents/MatX/examples/repro2.cu:174
Total Time taken (again): 13.7543 seconds
70000 500

From my repo

:~/Documents/GS-DBSCAN-CPP/build$ ./repro2
CUDA call successful: /home/hphi344/Documents/GS-DBSCAN-CPP/test/repro2.cu:174
Total Time taken (again): 13.7889 seconds
70000 500

Would you be able to run repro2 on your machine (from within the examples folder of the MatX repo) to see if you get the same behaviour?

Thanks v much

@HugoPhibbs (Author):

Hey, sorry if I'm pestering a bit, but just a quick nudge on this. Have you had a chance to reproduce the same results? Thx for all your help so far.

@cliffburdick (Collaborator):

Hi @HugoPhibbs , not pestering at all. Got caught up in other things but will take a look tomorrow.

cliffburdick reopened this Aug 15, 2024
@cliffburdick (Collaborator):

@HugoPhibbs I was able to reproduce it on the 3090. Will do some more investigating.

@cliffburdick (Collaborator):

Hi @HugoPhibbs, we tracked down the problem to register spills, possibly caused by the size of the iterator going into CUB. When you materialize the operator into a tensor before calling the norm, it improves the speed by about 4x. That's still not as fast as CuPy, so we'd like to keep looking into the remaining performance issue until it's faster than CuPy, but we wanted to make sure they're doing the same computation. Can you paste the entire CuPy example, including the main function, so we can run it standalone?

@HugoPhibbs (Author):

@cliffburdick sure, can do.

Here's the CuPy code:

import cupy as cp
import numpy as np
from timeit import default_timer as timer

def create_mock_A_B_matrices(n=70_000, k=2, m=2000, D=1024):
    A = np.random.randint(size=(n, 2 * k), low=0, high=2 * (D - 1),
                          dtype=cp.uint32)  # uint32 goes up to 4 billion, more than enough for max n
    B = np.random.randint(size=(2 * D, m), low=0, high=n - 1, dtype=cp.uint32)
    return cp.asarray(A), cp.asarray(B)  # Dunno why but this fixes error

def create_mock_mnist_dataset(n=70_000, d=784):
    """
    Creates a mock MNIST dataset for testing purposes

    Args:
        n: size of dataset
        d: dimensionality of dataset
        m: Number of query vectors per random vector
        k: number of random vectors per dataset vector
        D: number of random vectors

    Returns:
        X, A, B: dataset, A matrix, B matrix
    """
    X_n = np.random.uniform(size=(n, d))
    X = X_n.astype(np.float32)
    return cp.asarray(X)

def get_batch_size(n, d, k, m, alpha=1.2):
    """
    Calculates the batch size.

    Batch size is best a divisor of n, to maximise GPU utilisation (its better when batches are all the same size and the last one isn't a special case)

    Args:
        n: number of query vectors
        d: dimension of the query vectors
        k: k as described above
        m: m as described above
        alpha: a scaling factor to adjust the batch size. Default is 1.2. Mainly here as a toggle to maximise batch size without overloading the memory
    """

    # For some reason, no matter what batch size you choose, after 200 iterations
    # of the above for-loop the performance falls off a cliff.
    # alpha can be adjusted to make the batch size smaller or larger.

    batch_size = int((n * d * 2 * k * m) // ((1024 ** 3) * alpha))

    if batch_size == 0:
        return n

    for divisor in range(batch_size - 1, 0, -1):
        if n % divisor == 0:
            batch_size = divisor
            break

    return batch_size


def find_distances(X, A, B, alpha=1.2, batch_size = -1):
    """
    Finds the distances between each query vector and their candidate vectors

    I'm curious to see if this is faster than the Numba CUDA implementation

    # TODO add batch processing for X with large dimensions

    Args:
        X: dataset of vectors stored along the ROWS
        A: matrix of indices of the closest and furthest random vectors to each point in X. Has shape (len(X), 2*k)
        B: matrix of indices of the closest and furthest points to each random vector in D. Has shape (2*len(D), m)
        alpha: a scaling factor to adjust the batch size. See docs for get_batch_size for more info
    Returns:
        A matrix of distances between the query and candidate vectors. Has shape (len(X), 2*k*m)
    """
    k = A.shape[1] // 2
    m = B.shape[1]

    n = X.shape[0]
    d = X.shape[1]

    batch_size = batch_size if batch_size != -1 else get_batch_size(n, d, k, m, alpha=alpha)

    distances = cp.empty(shape=(n, 2 * k * m),
                         dtype=cp.float16)  # float32 causes a memory overload. float16 is fine (for eps 2DP)

    for i in range(0, n, batch_size):
        max_batch_idx = min(i + batch_size, X.shape[0])

        Z_batch = X[B[A[i:max_batch_idx]]]
        Z_batch_adj = Z_batch.reshape(batch_size, 2 * k * m,  d)

        Y_batch = Z_batch_adj - X[i:max_batch_idx, cp.newaxis, :]

        distances[i:max_batch_idx] = cp.linalg.norm(Y_batch, axis=2)

    return distances

k = 5
n = 70_000
m = 50
D = 1024
d = 784

X = create_mock_mnist_dataset(n=n, d=d)
print(X.shape)

A, B = create_mock_A_B_matrices(n=X.shape[0], k=k, m=m, D=D)
print(A.shape)
print(B.shape)

start = timer()
distances = find_distances(X, A, B, batch_size=2000)
cp.cuda.Device().synchronize()
print(f"Test took {timer() - start} seconds")
print(distances.shape)

I'm quite confident that they are doing the same computation, as they produce identical results for n=1000, and it's just doing the same thing over and over again with batches anyway.

@cliffburdick (Collaborator):

I'm getting 40s on a 3090 in Python with your code. Is that what you expect? It seems a lot slower than what you said above.

@HugoPhibbs (Author):

@cliffburdick that is very odd, not what I'd expect at all.

I'm getting 0.73 seconds or so:

Test took 0.7395999459549785 seconds
(70000, 500)
k: 5, n: 70000, m: 50, D: 1024, d: 784

See this Gist containing updated print statements: https://gist.github.com/HugoPhibbs/ba3ae26c9ff09ea997ece53c9b856399

My CuPy package is cupy-cuda12x 12.3.0

@cliffburdick (Collaborator):

> I'm getting 0.73 seconds or so: Test took 0.7395999459549785 seconds [...] My CuPy package is cupy-cuda12x 12.3.0

You're right, it was a caching issue. I'm getting close to yours when the JIT has completed.

@HugoPhibbs (Author):

> Hi @HugoPhibbs, we tracked down the problem to register spills possibly caused by the iterator size going into CUB. When you materialize the operator into a tensor before calling the norm it improves the speed about 4x. [...]

Hi @cliffburdick

Just so we're on the same page, what is the exact code fix you applied to materialise the operator into a tensor before calling the norm?

@cliffburdick (Collaborator):

> Just so we're on the same page, what is the exact code fix you applied to materialise the operator into a tensor before calling the norm?

Hi @HugoPhibbs, it was:

        auto tmp_tensor = make_tensor<typename YBatch_t_op::value_type>(Shape(YBatch_t_op), MATX_ASYNC_DEVICE_MEMORY, 0);
        (tmp_tensor = YBatch_t_op).run();
        auto YBatch_t_norm_op = matx::vector_norm(tmp_tensor, {2}, matx::NormOrder::L2);
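
(A possible follow-on, since every batch has the same shape: the temporary can be allocated once before the loop and reused each iteration, avoiding one async allocation per batch. A sketch under that assumption, reusing names from the repro - not the committed fix:)

auto tmp_tensor = matx::make_tensor<matx::matxFp16>({batchSize, 2*k*m, d}, matx::MATX_ASYNC_DEVICE_MEMORY, 0);

for (int i = 0; i < n; i += batchSize) {
    int maxBatchIdx = i + batchSize - 1;
    // ... same slice/remap/reshape chain as above to build YBatch_t_op ...
    (tmp_tensor = YBatch_t_op).run(); // one materialization per batch, no per-batch allocation
    auto YBatch_t_norm_op = matx::vector_norm(tmp_tensor, {2}, matx::NormOrder::L2);
    (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
}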

@cliffburdick (Collaborator):

@HugoPhibbs we have a fix/suggestion that will get you to 2x faster than cuPy, but we still have some work to do to go faster. We should be able to commit the first patch soon.

@cliffburdick (Collaborator):

@HugoPhibbs can you please pull the latest commit and compile with -DMATX_BUILD_32_BIT=ON?
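
(E.g., from an existing build directory, assuming the CMake workflow shown earlier in the thread:)

cmake -DMATX_BUILD_32_BIT=ON ..
make repro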

@HugoPhibbs (Author) commented Sep 21, 2024

Hi @cliffburdick thanks for the help.

I tried the change (rebuilt and installed, etc.); it certainly sped up the runtime, but it still takes 5 seconds (an improvement nonetheless).

I've since reimplemented the code in question with Torch. The Torch runtime is 0.79 seconds (very similar to CuPy), while the MatX runtime is around 5 seconds.

Please see the attached profiles from running my overall algorithm - I ran the code in question (it's part of an overall algorithm) with Torch and then with MatX. The end results of both runs are very similar, except MatX takes remarkably longer - the profile looks much different too.

profiles.zip

@cliffburdick (Collaborator):

Thanks @HugoPhibbs, was this still a 3090? I was getting 0.5s when I tested on that card, but I will retry.

@HugoPhibbs (Author):

yeah.

Actually, could you send the code you're using?

@cliffburdick (Collaborator):

We are using the repro.cu from this branch:

90bf114

@HugoPhibbs (Author):

Hi @cliffburdick

Sorry for the delay,

I brought the changes from 90bf114 into the latest main and ran it; I had to disable the treat-warnings-as-errors setting in the main CMake.

The results are:

Total Time taken: 2.08292 seconds
Total Time taken (again): 2.0831 seconds
70000 500

@cliffburdick (Collaborator):

Hi @HugoPhibbs. I rebased 688-repro and pushed. After doing that and building with 32-bit support on a 3090, here are my results:

YBatch_t_op:
    Shape:
        dim: 0: size: 250
        dim: 1: size: 500
        dim: 2: size: 784
Sync Time taken: 0.75648 seconds
Total Time taken: 0.779154 seconds
Total Time taken (again): 0.779977 seconds

Are you running the example exactly as it is?

@HugoPhibbs (Author) commented Sep 24, 2024

Hi @cliffburdick

My bad, forgot to build with 32-bit. My results are:

YBatch_t_op: 
    Shape: 
        dim: 0: size: 250
        dim: 1: size: 500
        dim: 2: size: 784
Sync Time taken: 0.92821 seconds
Total Time taken: 1.05661 seconds
Total Time taken (again): 1.05698 seconds

These are different from yours - even so, your results are still about the same speed as CuPy?

For debugging my nvidia-smi:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0 Off |                  N/A |
| 56%   56C    P0            131W /  390W |    1499MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        Off |   00000000:4A:00.0 Off |                  N/A |
|  0%   48C    P8             47W /  390W |      88MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1985      G   /usr/lib/xorg/Xorg                            161MiB |
|    0   N/A  N/A      3172      G   /usr/lib/xorg/Xorg                           1048MiB |
|    0   N/A  N/A      3357      G   /usr/bin/gnome-shell                           50MiB |
|    0   N/A  N/A   1165203      G   ...ures=SpareRendererForSitePerProcess          4MiB |
|    0   N/A  N/A   3596877      G   krusader                                        4MiB |
|    0   N/A  N/A   4047909      G   ...seed-version=20240904-050056.718000         58MiB |
|    1   N/A  N/A      1985      G   /usr/lib/xorg/Xorg                             34MiB |
|    1   N/A  N/A      3172      G   /usr/lib/xorg/Xorg                             34MiB |
+-----------------------------------------------------------------------------------------+

@cliffburdick (Collaborator):

Hi @HugoPhibbs, yes, it seems to be about the same as PyTorch. Previously I was comparing to your CuPy results. We have further optimizations we can make, but before doing that I think we're better off looking at your Python and reimplementing it, versus trying to optimize something that may not represent the original problem. Are you able to post your PyTorch code?

@HugoPhibbs (Author) commented Sep 24, 2024

Hi @cliffburdick.

Yep sure, I'm using LibTorch (C++ Torch) though - not PyTorch.

See my code here.

[Edit]: Realised LibTorch might not be all that helpful - there is no run script with random data like what we are doing above.

Here is the PyTorch code I used for prototyping. The result is:

Time for the loop: 0.2770 seconds
GPU took 659.4078369140625 milliseconds
Test took 0.6594308433122933 seconds
torch.Size([70000, 500])
k: 5, n: 70000, m: 50, D: 1024, d: 784

@cliffburdick (Collaborator):

@HugoPhibbs can you provide a main function too so we can run the whole thing?

@HugoPhibbs (Author):

Hi @cliffburdick,

Do you mean for the LibTorch? (LibTorch is a little hard to set up)

The Pytorch code has code at the bottom you can run.

@cliffburdick (Collaborator):

> Do you mean for the LibTorch? (LibTorch is a little hard to set up) [...] The Pytorch code has code at the bottom you can run.

I missed that. Thanks
