From afba567eac3d58bb29abbe99e58eb58858d646c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joan=20Alexis=20Glaun=C3=A8s?= Date: Tue, 2 Apr 2024 20:13:52 +0000 Subject: [PATCH] update --- pykeops/pykeops/sandbox/issue_362.py | 39 +++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/pykeops/pykeops/sandbox/issue_362.py b/pykeops/pykeops/sandbox/issue_362.py index d0318464..3aa56f46 100644 --- a/pykeops/pykeops/sandbox/issue_362.py +++ b/pykeops/pykeops/sandbox/issue_362.py @@ -1,5 +1,6 @@ import torch from pykeops.torch import LazyTensor +from time import time A, B, C, D = 32, 8, 16, 400 x = torch.randn(A, B, 1, D).unsqueeze(2).cuda() @@ -8,16 +9,40 @@ # x.shape: (A, B, 1, 1, D) # w.shape: (A, 1, C, D, D) +start = time() +res_torch_0 = torch.einsum("abde,ace->abd", w.view(A, C, D, D), x.view(A,B,D)) +end = time() +print("time for torch 0:", end-start) + +start = time() res_torch = (x*w).sum(axis=-1).sum(axis=1) -print(res_torch.shape) +end = time() +print("time for torch:", end-start) + +print(torch.norm(res_torch_0-res_torch)/torch.norm(res_torch)) +start = time() xi = LazyTensor(x.view(A, B, 1, D)) wi = LazyTensor(w.view(A, 1, C*D, D)) - res_keops = (xi | wi).sum(axis=1).view(A,C,D) -print(res_keops.shape) +print((xi | wi).sum(axis=1).shape) +end = time() +print("time for keops:", end-start) + +print(torch.norm(res_keops-res_torch)/torch.norm(res_torch)) + +start = time() +xp = x.permute(0,2,3,4,1)[...,None,:].contiguous() +wp = w.permute(0,2,3,4,1)[...,None,:].contiguous() +end1 = time() +# xp.shape: (A, 1, 1, D, 1, B) +# wp.shape: (A, C, D, D, 1, 1) + +xi = LazyTensor(xp) +wi = LazyTensor(wp) + +res_keops_alt = (xi*wi).sum(axis=-1).sum(axis=3).view(A,C,D) +end = time() +print("time for keops alt:", end-start, "(", end1-start, "for permute)") -print(torch.norm((res_keops-res_torch)/res_torch)) -print(torch.max((res_keops-res_torch)/res_torch)) -print(torch.min((res_keops-res_torch)/res_torch)) -print(torch.mean(torch.abs((res_keops-res_torch)/res_torch))) +print(torch.norm(res_keops_alt-res_torch)/torch.norm(res_torch))