From b08e0c4abfa56aaf646708cb08e8c38da4483144 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joan=20Alexis=20Glaun=C3=A8s?= Date: Fri, 8 Dec 2023 16:48:35 +0000 Subject: [PATCH] update --- pykeops/pykeops/sandbox/test_cupy.py | 31 ++++-- pykeops/pykeops/sandbox/test_cupy_dtw.py | 125 +++++++++++++++++++++++ 2 files changed, 148 insertions(+), 8 deletions(-) create mode 100644 pykeops/pykeops/sandbox/test_cupy_dtw.py diff --git a/pykeops/pykeops/sandbox/test_cupy.py b/pykeops/pykeops/sandbox/test_cupy.py index 48aaac23..eea7e447 100644 --- a/pykeops/pykeops/sandbox/test_cupy.py +++ b/pykeops/pykeops/sandbox/test_cupy.py @@ -3,7 +3,9 @@ from time import time from cupyx.profiler import benchmark import torch -from pykeops.torch import LazyTensor +from pykeops.torch import Genred + +cp.cuda.runtime.deviceSynchronize() loaded_from_source = r''' extern "C"{ @@ -65,13 +67,20 @@ def ker_gauss_raw(out,x,y,b): gridsize = 1+(N-1)//blocksize shared_mem = 4*blocksize*2 ker_gauss((gridsize,),(blocksize,),(N,N,out,x,y,b),shared_mem=shared_mem) + cp.cuda.runtime.deviceSynchronize() def ker_gauss_cupy(out,x,y,b): cp.sum(cp.exp(-(x[:,None]-y[None,:])**2) * b[None,:],axis=1,out=out) + cp.cuda.runtime.deviceSynchronize() def ker_gauss_torch(out,x,y,b): torch.sum(torch.exp(-(x[:,None]-y[None,:])**2) * b[None,:],axis=1,out=out) + cp.cuda.runtime.deviceSynchronize() +fun_keops = Genred("Exp(-Square(X-Y))*B", ["X=Vi(0,1)", "Y=Vj(1,1)", "B=Vj(2,1)"], axis=1) +def ker_gauss_keops(out,x,y,b): + fun_keops(x[:,None],y[:,None],b[:,None],out=out) + cp.cuda.runtime.deviceSynchronize() def bench_time(fun,args,n_repeat=1): for k in range(2): @@ -81,18 +90,24 @@ def bench_time(fun,args,n_repeat=1): end = time() return f"time for {fun.__name__} : {end-start}" -N = 10000 -n_repeat = 1 +N = 100000 +n_repeat = 10 x = cp.random.rand(N, dtype=cp.float32) y = cp.random.rand(N, dtype=cp.float32) b = cp.random.rand(N, dtype=cp.float32) out = cp.zeros(N, dtype=cp.float32) +out_ref = cp.zeros(N, dtype=cp.float32) + +xt = torch.as_tensor(x, device='cuda') +yt = torch.as_tensor(y, device='cuda') +bt = torch.as_tensor(b, device='cuda') +outt = torch.as_tensor(out, device='cuda') for bench_method in (bench_time,benchmark): print(f"\n----------------------\nUsing {bench_method.__name__}\n---------------------") - print(bench_method(ker_gauss_raw,(out,x,y,b),n_repeat=n_repeat)) + print(bench_method(ker_gauss_raw,(out_ref,x,y,b),n_repeat=n_repeat)) + print(bench_method(ker_gauss_keops,(outt,xt,yt,bt),n_repeat=n_repeat)) + print("relative error : ", cp.linalg.norm(outt-out_ref)/cp.linalg.norm(out_ref)) if N<20000: - out_ref = cp.zeros(N, dtype=cp.float32) - print(bench_method(ker_gauss_cupy,(out_ref,x,y,b),n_repeat=n_repeat)) - print("relative error : ", cp.linalg.norm(out-out_ref)/cp.linalg.norm(out_ref)) - print(bench_method(ker_gauss_torch,(out_ref,x,y,b),n_repeat=n_repeat)) \ No newline at end of file + print(bench_method(ker_gauss_cupy,(out,x,y,b),n_repeat=n_repeat)) + print(bench_method(ker_gauss_torch,(outt,xt,yt,bt),n_repeat=n_repeat)) \ No newline at end of file diff --git a/pykeops/pykeops/sandbox/test_cupy_dtw.py b/pykeops/pykeops/sandbox/test_cupy_dtw.py new file mode 100644 index 00000000..d85bf899 --- /dev/null +++ b/pykeops/pykeops/sandbox/test_cupy_dtw.py @@ -0,0 +1,125 @@ +import cupy as cp +import numpy as np +from time import time +from cupyx.profiler import benchmark + +cp.cuda.runtime.deviceSynchronize() + +loaded_from_source = r''' + +#define MIN(x,y) (((x) < (y)) ? (x) : (y)) +#define MIN(x,y,z) (MIN(x,MIN(y,z))) + +#include +using namespace cooperative_groups; + +extern "C"{ + +__global__ void dtw(int nx, int ny, float *x, float *y, float *out) +{ + int blocksize = blockDim.x; + int iblock = blockIdx.x; + + // center buffer + float *bufref = out + nx - iblock * blocksize; + float *buf = bufref; + + // get the index of the current thread + int iloc = threadIdx.x; + int i = indblock * blocksize + iloc; + int ibuf = -iloc-1; + + // declare shared mem - size to allocate must be 4*blocksize+1 + extern __shared__ float shared[]; + float *yloc = shared; + float *bufloc = yloc + 2*blocksize + 1; + int buflocsize = 3*blocksize + 1; + + float xi, d2ij; + + if (i < nx) { + xi = x[i]; + bufloc[ibuf] = buf[ibuf]; + if (iloc==0) + bufloc[0] = buf[0]; + + for (int jblock=0, jstart=0; jstart