Skip to content

Commit

Permalink
compare cube to per slice hess
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 10, 2024
1 parent fd195a1 commit 8d0d3e8
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 53 deletions.
9 changes: 5 additions & 4 deletions pfb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
__version__ = '0.0.4'

def set_envs(nthreads, ncpu):
os.environ["OMP_NUM_THREADS"] = str(nthreads)
os.environ["OPENBLAS_NUM_THREADS"] = str(nthreads)
os.environ["MKL_NUM_THREADS"] = str(nthreads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(nthreads)
os.environ["OMP_NUM_THREADS"] = '2'
os.environ["OPENBLAS_NUM_THREADS"] = '2'
os.environ["MKL_NUM_THREADS"] = '2'
os.environ["VECLIB_MAXIMUM_THREADS"] = '2'
os.environ["NPY_NUM_THREADS"] = '2'
os.environ["NUMBA_NUM_THREADS"] = str(nthreads)
os.environ["JAX_PLATFORMS"] = 'cpu'
os.environ["JAX_ENABLE_X64"] = 'True'
Expand Down
8 changes: 4 additions & 4 deletions pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,13 @@ def image_data_products(dsl,
x = np.linspace(-5, 5, 150)
y = norm.pdf(x, 0, 1)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(8, 12))
ax[0,0].hist((residual_vis.real*wgtp).ravel(), bins=15, density=True)
ax[0,0].hist((residual_vis.real*np.sqrt(wgtp/2)).ravel(), bins=15, density=True)
ax[0,0].plot(x, y, 'k')
ax[0,1].hist((residual_vis.real*wgt).ravel(), bins=15, density=True)
ax[0,1].hist((residual_vis.real*np.sqrt(wgt/2)).ravel(), bins=15, density=True)
ax[0,1].plot(x, y, 'k')
ax[1,0].hist((residual_vis.imag*wgtp).ravel(), bins=15, density=True)
ax[1,0].hist((residual_vis.imag*np.sqrt(wgtp/2)).ravel(), bins=15, density=True)
ax[1,0].plot(x, y, 'k')
ax[1,1].hist((residual_vis.imag*wgt).ravel(), bins=15, density=True)
ax[1,1].hist((residual_vis.imag*np.sqrt(wgt/2)).ravel(), bins=15, density=True)
ax[1,1].plot(x, y, 'k')
import os
cwd = os.getcwd()
Expand Down
150 changes: 105 additions & 45 deletions pfb/operators/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def _hessian_psf_slice(x, # input image, not overwritten
Tikhonov regularised Hessian approx
"""
nx, ny = x.shape
xpad[nx:, :] = 0.0
xpad[0:nx, ny:] = 0.0
xpad.fill(0.0)
if beam is None:
xpad[0:nx, 0:ny] = x
np.copyto(xpad[0:nx, 0:ny], x)
# xpad[0:nx, 0:ny] = x
else:
xpad[0:nx, 0:ny] = x*beam
r2c(xpad, axes=(0, 1), nthreads=nthreads,
Expand All @@ -127,7 +127,7 @@ def _hessian_psf_slice(x, # input image, not overwritten
c2r(xhat, axes=(0, 1), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[0:nx, 0:ny]
np.copyto(xout, xpad[0:nx, 0:ny])

if beam is not None:
xout *= beam
Expand Down Expand Up @@ -189,7 +189,7 @@ def hess_direct(x, # input image, not overwritten
eta=1,
mode='forward'):
nband, nx, ny = x.shape
xpad[...] = 0.0
xpad.fill(0.0)
xpad[:, 0:nx, 0:ny] = x * taperxy[None]
r2c(xpad, out=xhat, axes=(1,2),
forward=True, inorm=0, nthreads=nthreads)
Expand All @@ -200,7 +200,7 @@ def hess_direct(x, # input image, not overwritten
c2r(xhat, axes=(1, 2), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[:, 0:nx, 0:ny]
np.copyto(xout, xpad[:, 0:nx, 0:ny])
xout *= taperxy[None]
return xout

Expand All @@ -219,7 +219,7 @@ def hess_direct_slice(x, # input image, not overwritten
Note eta must be relative to wsum (peak of PSF)
'''
nx, ny = x.shape
xpad[...] = 0.0
xpad.fill(0.0)
xpad[0:nx, 0:ny] = x * taperxy
r2c(xpad, out=xhat, axes=(0,1),
forward=True, inorm=0, nthreads=nthreads)
Expand All @@ -230,7 +230,7 @@ def hess_direct_slice(x, # input image, not overwritten
c2r(xhat, axes=(0, 1), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[0:nx, 0:ny]
np.copyto(xout, xpad[0:nx, 0:ny])
xout *= taperxy
return xout

Expand All @@ -256,6 +256,7 @@ def __init__(self, nx, ny, abspsf,
assert self.ny == beam.shape[2]
self.beam = beam
else:
# self.beam = None
self.beam = (None,)*self.nband
self.ny_psf = 2*(self.nyo2-1)
self.nx_pad = self.nx_psf - self.nx
Expand All @@ -277,6 +278,10 @@ def __init__(self, nx, ny, abspsf,
self.xpad = empty_noncritical((self.nx_psf, self.ny_psf),
dtype='f8')
# output cube
# self.xhat = empty_noncritical((self.nband, self.nx_psf, self.nyo2),
# dtype='c16')
# self.xpad = empty_noncritical((self.nband, self.nx_psf, self.ny_psf),
# dtype='f8')
self.xout = empty_noncritical((self.nband, self.nx, self.ny),
dtype='f8')

Expand All @@ -291,18 +296,6 @@ def __init__(self, nx, ny, abspsf,

# for beam application in direct mode
self.min_beam = min_beam

self.xpad[...] = 0.0
self.xpad[nx:, :] = 1.0
self.xpad[0:nx, ny:] = 1.0
self.pad_mask = self.xpad>0
self.ix_pad = np.argwhere(self.pad_mask)[:, 0]
self.iy_pad = np.argwhere(self.pad_mask)[:, 1]
# self.xpad[...] = 0.0
# self.xpad[self.ix_pad, self.iy_pad] = 1.0
# import matplotlib.pyplot as plt
# plt.imshow(self.pad_mask.astype(np.float64))
# plt.show()
# import ipdb; ipdb.set_trace()

def dot(self, x):
Expand All @@ -318,38 +311,26 @@ def dot(self, x):
assert nx == self.nx
assert ny == self.ny

tr2c = 0.0
tconv = 0.0
tc2r = 0.0
ttot = 0.0
tii = time()
for b in range(nband):
self.xpad[nx:, :] = 0.0
self.xpad[0:nx, ny:] = 0.0
self.xpad.fill(0.0)
if self.beam[b] is None:
self.xpad[0:nx, 0:ny] = xtmp[b]
np.copyto(self.xpad[0:nx, 0:ny], xtmp[b])
else:
self.xpad[0:nx, 0:ny] = xtmp[b]*self.beam[b]
ti = time()
r2c(self.xpad, axes=(0, 1), nthreads=self.nthreads,
forward=True, inorm=0, out=self.xhat)
tr2c += (time() - ti)
ti = time()
ne.evaluate('xhat * abspsf',
out=self.xhat,
local_dict={
'xhat': self.xhat,
'abspsf': self.abspsf[b]},
casting='unsafe')
# self.xhat *= self.abspsf[b]
tconv += (time() - ti)
ti = time()
c2r(self.xhat, axes=(0, 1), forward=False, out=self.xpad,
lastsize=self.ny_psf, inorm=2, nthreads=self.nthreads,
allow_overwriting_input=True)
tc2r += (time() - ti)
if self.beam[b] is None:
self.xout[b] = self.xpad[0:nx, 0:ny]
np.copyto(self.xout[b], self.xpad[0:nx, 0:ny])
else:
self.xout[b] = self.xpad[0:nx, 0:ny]*self.beam[b]
ne.evaluate('xout + xtmp * eta',
Expand All @@ -359,17 +340,82 @@ def dot(self, x):
'xtmp': xtmp,
'eta': self.eta[:, None, None]},
casting='unsafe')
ttot = time() - tii
ttally = (tr2c + tconv + tc2r)/ttot
tr2c /= ttot
tconv /= ttot
tc2r /= ttot
print('tr2c = ', tr2c)
print('tconv =', tconv)
print('tc2r = ', tc2r)
print('ttally = ', ttally)
print('ttot = ', time() - tii)
return self.xout

# def dot(self, x):
# if len(x.shape) == 3:
# xtmp = x
# elif len(x.shape) == 2:
# xtmp = x[None, :, :]
# else:
# raise ValueError("Unsupported number of input dimensions")

# nband, nx, ny = xtmp.shape
# assert nband == self.nband
# assert nx == self.nx
# assert ny == self.ny

# tii = time()
# ti = time()
# self.xpad.fill(0.0)
# tfill = time() - ti
# ti = time()
# if self.beam is None:
# np.copyto(self.xpad[:, 0:nx, 0:ny], xtmp)
# else:
# self.xpad[:, 0:nx, 0:ny] = xtmp*self.beam
# tpad = time() - ti
# ti = time()
# r2c(self.xpad, axes=(1, 2), nthreads=self.nthreads,
# forward=True, inorm=0, out=self.xhat)
# tr2c = time() - ti
# ti = time()
# ne.evaluate('xhat * abspsf',
# out=self.xhat,
# local_dict={
# 'xhat': self.xhat,
# 'abspsf': self.abspsf},
# casting='unsafe')
# tconv = time() - ti
# ti = time()
# c2r(self.xhat, axes=(1, 2), forward=False, out=self.xpad,
# lastsize=self.ny_psf, inorm=2, nthreads=self.nthreads,
# allow_overwriting_input=True)
# tc2r = time() - ti
# ti = time()
# if self.beam is None:
# np.copyto(self.xout, self.xpad[:, 0:nx, 0:ny])
# else:
# self.xout = self.xpad[:, 0:nx, 0:ny]*self.beam
# tcopy = time() - ti
# ti = time()
# ne.evaluate('xout + xtmp * eta',
# out=self.xout,
# local_dict={
# 'xout': self.xout,
# 'xtmp': xtmp,
# 'eta': self.eta[:, None, None]},
# casting='unsafe')
# tplus = time() - ti
# ttot = time() - tii
# tfill /= ttot
# tpad /= ttot
# tr2c /= ttot
# tconv /= ttot
# tc2r /= ttot
# tcopy /= ttot
# tplus /= ttot
# # print('tfill = ', tfill)
# # print('tpad = ', tpad)
# # print('tr2c = ', tr2c)
# # print('tconv = ', tconv)
# # print('tc2r = ', tc2r)
# # print('tcopy = ', tcopy)
# # print('tplus = ', tplus)
# print('ttot = ', time() - tii)
# return self.xout

def hdot(self, x):
# Hermitian operator
return self.dot(x)
Expand All @@ -388,7 +434,21 @@ def idot(self, x, mode='psf', x0=None):
assert ny == self.ny

if x0 is None:
x0 = (None,)*self.nband
x0 = np.zeros_like(xtmp)
for b in range(self.nband):
x0[b] = hess_direct_slice(xtmp,
xpad=self.xpad,
xhat=self.xhat,
xout=self.xout[b],
abspsf=self.abspsf[b],
taperxy=self.taperxy,
lastsize=self.ny_psf,
nthreads=self.nthreads,
eta=self.eta[b],
mode='backward')
if self.beam[b] is not None:
mask = (self.xout[b] > 0) & (self.beam[b] > self.min_beam)
self.xout[b, mask] /= self.beam[b, mask]**2

if mode=='direct':
for b in range(self.nband):
Expand Down

0 comments on commit 8d0d3e8

Please sign in to comment.