Skip to content

Commit

Permalink
testing NL means denoising
Browse files Browse the repository at this point in the history
  • Loading branch information
joanglaunes committed Mar 14, 2022
1 parent 4140789 commit a4bf92f
Show file tree
Hide file tree
Showing 4 changed files with 1,474 additions and 0 deletions.
1,131 changes: 1,131 additions & 0 deletions pykeops/pykeops/sandbox/Non_Local_Means_for_denoising.ipynb

Large diffs are not rendered by default.

Binary file added pykeops/pykeops/sandbox/fruits.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
194 changes: 194 additions & 0 deletions pykeops/pykeops/sandbox/plot_image_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
=============
Image Patches
=============
Example of operation over image patches with KeOps.
Given an image I of shape (m,n,d), we compute the K nearest patches of
every s-by-s image patch of the image
"""

###############################################################
# Setup
# -------------
# Standard imports:

import torch
from pykeops.torch import LazyTensor

id_device = 0
device = 'cuda:'+str(id_device) if torch.cuda.is_available() else 'cpu'

###############################################################
# Parameters (s size of block, K number of nearest neighbours)
s, K = 5, 10

###############################################################
# Loading image
import imageio
if 'cuda' in device:
imfile = "fruits.png" #"https://homepages.cae.wisc.edu/~ece533/images/fruits.png"
else:
imfile = "http://helios.math-info.univ-paris5.fr/~glaunes/pikachu.bmp"
I = torch.tensor(imageio.imread(imfile)).float().to(device)
m, n, d = I.shape

###############################################################
# PyTorch implementation (works only for small images)
# ----------------------------------------------------
###############################################################
# function to build image patches with PyTorch
def torch_patches(I,s):
# inputs : I torch tensor of shape (m,n,d), image
# s : integer, size of patches
# output : torch tensor of shape ((m-s+1)*(n-s+1),d*s**2), matrix of all s-by-s patches of I
m, n, d = I.shape
P = torch.zeros((m-s+1)*(n-s+1),s**2,d)
k = 0
for i in range(s):
for j in range(s):
P[:,k,:] = I[i:i+m-s+1,j:j+n-s+1,:].flatten().view(-1,d)
k += 1
return P.view(-1,d*s**2)



###############################################################
# Testing with PyTorch if image is not too large
import time
if m*n<20000:
start = time.time()
P = torch_patches(I,s)
D2 = ((P[:,None,:]-P[None,:,:])**2).sum(dim=2)
out_torch = D2.argsort(dim=1)[:,:K]
print("elapsed time with PyTorch : ",time.time()-start)

###############################################################
# KeOps implementation
# ----------------------------------------------------

###############################################################
# Function to represent image patches as a LazyTensor

def LazyTensor_patches(I,s,axis,use_ranges=True):
# input : I torch tensor of shape (m,n,d), image
# s : integer, size of patches
# output : LazyTensor of shape ((m-s+1)*n-s+1,d*s**2) representing all s-by-s patches of I,
# N.B. there are (m-s)*(s-1) "garbage" patches, corresponding to indices (i,m-s+1),...(i,m-1)
# for each row i except the last one. These patches will not be taken into account in computations
# if use_ranges=True, but will remain in any case in the output as zero valued rows.
m, n, d = I.shape
I = I.view(m*n,d)
ind_last = (m-s+1)*n-s+1
for i in range(s):
for j in range(s):
if i==0 and j==0:
P = LazyTensor(torch.narrow(I,0,0,ind_last),axis=axis)
else:
ind_shift = i*n + j
P = P.concat(LazyTensor(torch.narrow(I,0,ind_shift,ind_last),axis=axis))
if use_ranges:
# we define the range of computation using the range arguments of KeOps
# This is used to avoid computing over "garbage" patches
ranges = torch.tensor([[0,ind_last]]).int().to(I.device)
slices = torch.tensor([m-s+1]).int().to(I.device)
rgm = torch.arange(m-s+1)[:,None].int()
red_ranges = torch.cat((rgm*n,rgm*n+n-s+1),dim=1).to(I.device)
return P, (ranges,slices,red_ranges,ranges,slices,red_ranges)
else:
return P

###############################################################
# Testing with KeOps - using ranges

start = time.time()

# creating LazyTensor objects for patches and computing
P_i, ranges = LazyTensor_patches(I,s,axis=0)
P_j, ranges = LazyTensor_patches(I,s,axis=1)
D2 = P_i.sqdist(P_j)
out_keops = D2.argKmin(dim=1, K=K, ranges=ranges)

# now post-processing to get rid of the garbage rows of the output
ind_keep = torch.arange(m*n).view(m,n)[:-s+1,:-s+1].flatten()
out_keops = out_keops[ind_keep,:]
# Also here the output corresponds to patch indices, for the list of patches that includes
# garbage patches. So we need to convert the indices
q = torch.div(out_keops, n, rounding_mode='floor')
r = out_keops % n
out_keops = q*(n-s+1)+r

print("elapsed time with KeOps (ranges) : ",time.time()-start)


###############################################################
# Testing with KeOps - using padding and no ranges

###############################################################
# padding the image with very large values on the right border
Ipad = torch.cat((I,1e20*torch.ones(m,1,d, device=device)),dim=1)

start = time.time()

# creating LazyTensor objects for patches and computing
P_i = LazyTensor_patches(Ipad,s,axis=0,use_ranges=False)
P_j = LazyTensor_patches(Ipad,s,axis=1,use_ranges=False)
D2 = P_i.sqdist(P_j)
out_keops = D2.argKmin(dim=1,K=K)

# now post-processing to get rid of the garbage rows of the output
ind_keep = torch.arange(m*(n+1)).view(m,n+1)[:-s+1,:-s].flatten()
out_keops = out_keops[ind_keep,:]
# Also here the output corresponds to patch indices, for the list of patches that includes
# garbage patches. So we need to convert the indices
q = torch.div(out_keops, n+1, rounding_mode='floor')
r = out_keops % (n+1)
out_keops = q*(n-s+1)+r

print("elapsed time with KeOps (padding) : ",time.time()-start)

###############################################################
# displaying the image with some results
# ----------------------------------------------------

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
I = I.cpu().numpy().astype(np.uint8)
def ShowNearestPatches(I,i,j):
def add_patch_box(I,i,j,s,clr):
B = I.copy()
B[i:i+s,j,:] = clr
B[i:i+s,j+s-1,:] = clr
B[i,j:j+s,:] = clr
B[i+s-1,j:j+s,:] = clr
return np.mean([I,B],axis=0).astype(np.uint8)
def add_segment(I,i,j,ik,jk,clr):
B = I.copy()
n = 100
iseg = np.linspace(i,ik,n).astype(np.int32)
jseg = np.linspace(j,jk,n).astype(np.int32)
for k in range(n):
B[iseg[k],jseg[k],:] = clr
return np.mean([I,B],axis=0).astype(np.uint8)
red = np.array([255,0,0]).astype(np.uint8)
blue = np.array([0,0,255]).astype(np.uint8)
ind = i*(n-s+1)+j
for k in range(1,K):
indk = out_keops[ind,k].cpu().numpy().astype(np.int32)
ik, jk = indk//(n-s+1), indk%(n-s+1)
I = add_patch_box(I,ik,jk,s,blue)
I = add_segment(I,i+s//2,j+s//2,ik+s//2,jk+s//2,blue)
I = add_patch_box(I,i,j,s,red)
return I
for k in range(10 if 'cuda' in device else 1):
i, j = np.random.randint(m-s+1), np.random.randint(n-s+1)
I = ShowNearestPatches(I,i,j)
I = Image.fromarray(I, 'RGB')
fig = plt.figure(figsize=(6,6))
fig.suptitle(f'one {s}x{s} patch and its {K} nearest neighbours')
plt.imshow(I)
plt.show()


149 changes: 149 additions & 0 deletions pykeops/pykeops/sandbox/plot_non_local_means.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@

"""
=======================================
Image denoising with non local means
=======================================
Another example of operation over image patches with KeOps.
Denoising via non local means (brute-force)
"""

###############################################################
# Setup
# -------------
# Standard imports:

import torch
from pykeops.torch import LazyTensor

id_device = 1
device = 'cuda:'+str(id_device) if torch.cuda.is_available() else 'cpu'

###############################################################
# Parameters (s size of block, K number of nearest neighbours)
K = 10

###############################################################
# Loading image
import imageio
if 'cuda' in device:
imfile = "fruits.png" #"https://homepages.cae.wisc.edu/~ece533/images/fruits.png"
s = 5
else:
imfile = "http://helios.math-info.univ-paris5.fr/~glaunes/pikachu.bmp"
s = 3
sd2 = s//2
I = torch.tensor(imageio.imread(imfile)).float().to(device)
sigma = 30.
I += torch.randn(I.shape, device=device)*sigma
I = I.clamp(0,255)
h2 = (1.*sigma)**2

m, n, d = I.shape

###############################################################
# PyTorch implementation (works only for small images)
# ----------------------------------------------------
###############################################################
# function to build image patches with PyTorch
def torch_patches(I,s):
# inputs : I torch tensor of shape (m,n,d), image
# s : integer, size of patches
# output : torch tensor of shape ((m-s+1)*(n-s+1),d*s**2), matrix of all s-by-s patches of I
m, n, d = I.shape
P = torch.zeros((m-s+1)*(n-s+1),s**2,d)
k = 0
for i in range(s):
for j in range(s):
P[:,k,:] = I[i:i+m-s+1,j:j+n-s+1,:].flatten().view(-1,d)
k += 1
return P.view(-1,d*s**2)

###############################################################
# Testing with PyTorch if image is not too large
import time
if m*n<20000:
start = time.time()
P = torch_patches(I,s)
D2 = ((P[:,None,:]-P[None,:,:])**2).sum(dim=2)
K = (-D2/(d*s**2*h2)).exp()
C = K.sum(dim=1)
out_torch = (K[:,:,None]*I[sd2:-sd2,sd2:-sd2,:].reshape((1,(m-s+1)*(n-s+1),d))).sum(dim=1) / C[:,None]
out_torch = out_torch.reshape((m-s+1,n-s+1,d))
print("elapsed time with PyTorch : ",time.time()-start)


###############################################################
# KeOps implementation
# ----------------------------------------------------

###############################################################
# Function to represent image patches as a LazyTensor

def LazyTensor_patches(I,s,axis,use_ranges=True):
# input : I torch tensor of shape (m,n,d), image
# s : integer, size of patches
# output : LazyTensor of shape ((m-s+1)*n-s+1,d*s**2) representing all s-by-s patches of I,
# N.B. there are (m-s)*(s-1) "garbage" patches, corresponding to indices (i,m-s+1),...(i,m-1)
# for each row i except the last one. These patches will not be taken into account in computations
# if use_ranges=True, but will remain in any case in the output as zero valued rows.
m, n, d = I.shape
I = I.view(m*n,d)
ind_last = (m-s+1)*n-s+1
for i in range(s):
for j in range(s):
if i==0 and j==0:
P = LazyTensor(torch.narrow(I,0,0,ind_last),axis=axis)
else:
ind_shift = i*n + j
P = P.concat(LazyTensor(torch.narrow(I,0,ind_shift,ind_last),axis=axis))
if use_ranges:
# we define the range of computation using the range arguments of KeOps
# This is used to avoid computing over "garbage" patches
ranges = torch.tensor([[0,ind_last]]).int().to(I.device)
slices = torch.tensor([m-s+1]).int().to(I.device)
rgm = torch.arange(m-s+1)[:,None].int()
red_ranges = torch.cat((rgm*n,rgm*n+n-s+1),dim=1).to(I.device)
return P, (ranges,slices,red_ranges,ranges,slices,red_ranges)
else:
return P

###############################################################
# Testing with KeOps - using ranges
if True:
start = time.time()
# creating LazyTensor objects for patches and computing
P_i, ranges = LazyTensor_patches(I,s,axis=0)
P_j, ranges = LazyTensor_patches(I,s,axis=1)
D2 = P_i.sqdist(P_j)
K = (-D2/(d*s**2*h2)).exp()
C = K.sum(dim=1)
ind_keep = torch.arange(m*n).view(m,n)[:-s+1,:-s+1].flatten()
I_ = torch.zeros((m-s+1)*n-s+1,d, device=device)
I_[ind_keep,:] = I[sd2:-sd2,sd2:-sd2,:].reshape((m-s+1)*(n-s+1),d)
out_keops = (K*I_[None,:,:]).sum(dim=1) / C
out_keops = out_keops[ind_keep,:]
out_keops = out_keops.reshape((m-s+1,n-s+1,d))
print("elapsed time with KeOps (ranges) : ",time.time()-start)




###############################################################
# displaying the results
# ----------------------------------------------------

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
def plot_image(I, title=""):
I = I.cpu().numpy().astype(np.uint8)
I = Image.fromarray(I, 'RGB')
fig = plt.figure(figsize=(6,6))
fig.suptitle(title)
plt.imshow(I)

plot_image(I, "noisy image")
plot_image(out_keops, "denoised image")
plt.show()

0 comments on commit a4bf92f

Please sign in to comment.