Commit 1ff7673
remove QuartiCal dependency
landmanbester committed Sep 4, 2024
1 parent bd68b5c commit 1ff7673
Showing 9 changed files with 12 additions and 331 deletions.
5 changes: 2 additions & 3 deletions pfb/__init__.py
@@ -40,7 +40,7 @@ def set_client(nworkers, log, stack=None, host_address=None,
     if client_log_level == 'error':
         logging.getLogger("distributed").setLevel(logging.ERROR)
         logging.getLogger("bokeh").setLevel(logging.ERROR)
-        logging.getLogger("tornado").setLevel(logging.ERROR)
+        logging.getLogger("tornado").setLevel(logging.CRITICAL)
     elif client_log_level == 'warning':
         logging.getLogger("distributed").setLevel(logging.WARNING)
         logging.getLogger("bokeh").setLevel(logging.WARNING)
@@ -55,7 +55,6 @@ def set_client(nworkers, log, stack=None, host_address=None,
         logging.getLogger("tornado").setLevel(logging.DEBUG)

     import dask
-    dask.config.set({'distributed.comm.compression': 'lz4'})
     # set up client
     host_address = host_address or os.environ.get("DASK_SCHEDULER_ADDRESS")
     if host_address is not None:
@@ -68,7 +67,7 @@ def set_client(nworkers, log, stack=None, host_address=None,
         dask.config.set({
             'distributed.comm.compression': {
                 'on': True,
-                'type': 'blosc'
+                'type': 'lz4'
             }
         })
         cluster = LocalCluster(processes=True,
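Aside: taken together, the two compression hunks above drop the early flat-form setting ('distributed.comm.compression': 'lz4') and switch the nested form used on the LocalCluster path from 'blosc' to 'lz4'. A minimal sketch of the surviving configuration, assuming dask, distributed and the lz4 package are installed (illustrative, not part of this commit):

    import dask

    # nested form of the comm compression setting, as set above for LocalCluster
    dask.config.set({
        'distributed.comm.compression': {
            'on': True,
            'type': 'lz4'
        }
    })

    # read the effective value back
    print(dask.config.get('distributed.comm.compression'))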
1 change: 0 additions & 1 deletion pfb/operators/gridder.py
@@ -13,7 +13,6 @@
 from ducc0.wgridder.experimental import vis2dirty, dirty2vis
 from ducc0.fft import c2r, r2c, c2c
 from africanus.constants import c as lightspeed
-from quartical.utils.dask import Blocker
 from pfb.utils.weighting import counts_to_weights, _compute_counts
 from pfb.utils.beam import eval_beam
 from pfb.utils.naming import xds_from_list
2 changes: 1 addition & 1 deletion pfb/operators/psi.py
@@ -286,7 +286,7 @@ def __init__(self, nband, nx, ny, bases, nlevel, nthreads):
         self.Nxmax = self.psib[0].Nxmax
         self.Nymax = self.psib[0].Nymax

-        self.nthreads_per_band = nthreads//nband
+        self.nthreads_per_band = np.maximum(1, nthreads//nband)

     def dot(self, x, alphao):
         '''
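Aside: the psi.py change guards against a zero thread count when there are more bands than threads, since integer division rounds down. An illustrative sketch (not from the commit):

    import numpy as np

    nthreads, nband = 4, 8                   # more bands than threads
    print(nthreads // nband)                 # 0 -> the old expression requests no threads
    print(np.maximum(1, nthreads // nband))  # 1 -> the new expression keeps at least one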
1 change: 0 additions & 1 deletion pfb/utils/correlations.py
@@ -4,7 +4,6 @@
 from dask.graph_manipulation import clone
 import dask.array as da
 from xarray import Dataset
-from quartical.utils.numba import coerce_literal
 from operator import getitem
 from pfb.utils.beam import interp_beam

305 changes: 0 additions & 305 deletions pfb/utils/misc.py
@@ -21,7 +21,6 @@
 from collections import namedtuple
 from africanus.coordinates.coordinates import radec_to_lmn
 import xarray as xr
-from quartical.utils.dask import Blocker
 from scipy.interpolate import RegularGridInterpolator
 from scipy.linalg import solve_triangular
 import sympy as sm
@@ -763,310 +762,6 @@ def chunkify_rows(time, utimes_per_chunk, daskify_idx=False):
     return tuple(row_chunks), time_bin_indices, time_bin_counts


-def rephase_vis(vis, uvw, radec_in, radec_out):
-    return da.blockwise(_rephase_vis, 'rf',
-                        vis, 'rf',
-                        uvw, 'r3',
-                        radec_in, None,
-                        radec_out, None,
-                        dtype=vis.dtype)
-
-def _rephase_vis(vis, uvw, radec_in, radec_out):
-    l_in, m_in, n_in = radec_to_lmn(radec_in)
-    l_out, m_out, n_out = radec_to_lmn(radec_out)
-    return vis * np.exp(1j*(uvw[:, 0]*(l_out-l_in) +
-                            uvw[:, 1]*(m_out-m_in) +
-                            uvw[:, 2]*(n_out-n_in)))
-
-
-# TODO - should allow coarsening to values other than 1
-def concat_row(xds):
-    times_in = []
-    freqs = []
-    for ds in xds:
-        times_in.append(ds.time_out)
-        freqs.append(ds.freq_out)
-
-    times_in = np.unique(times_in)
-    freqs = np.unique(freqs)
-
-    nband = freqs.size
-    ntime_in = times_in.size
-
-    if ntime_in == 1: # no need to concatenate
-        return xds
-
-    # do merge manually because different variables require different
-    # treatment anyway eg. the BEAM should be computed as a weighted sum
-    xds_out = []
-    for b in range(nband):
-        xdsb = []
-        times = []
-        freq_max = []
-        freq_min = []
-        time_max = []
-        time_min = []
-        nu = freqs[b]
-        for ds in xds:
-            if ds.freq_out == nu:
-                xdsb.append(ds)
-                times.append(ds.time_out)
-                freq_max.append(ds.freq_max)
-                freq_min.append(ds.freq_min)
-                time_max.append(ds.time_max)
-                time_min.append(ds.time_min)
-
-        wgt = [ds.WEIGHT for ds in xdsb]
-        vis = [ds.VIS for ds in xdsb]
-        mask = [ds.MASK for ds in xdsb]
-        uvw = [ds.UVW for ds in xdsb]
-
-        # get weighted sum of beams
-        beam = sum_beam(xdsb)
-        l_beam = xdsb[0].l_beam.data
-        m_beam = xdsb[0].m_beam.data
-
-        wgto = xr.concat(wgt, dim='row')
-        viso = xr.concat(vis, dim='row')
-        masko = xr.concat(mask, dim='row')
-        uvwo = xr.concat(uvw, dim='row')
-
-        xdso = xr.merge((wgto, viso, masko, uvwo))
-        xdso = xdso.assign({'BEAM': (('l_beam', 'm_beam'), beam)})
-        xdso['FREQ'] = xdsb[0].FREQ # is this always going to be the case?
-
-        xdso = xdso.chunk({'row':-1, 'l_beam':-1, 'm_beam':-1})
-
-        xdso = xdso.assign_coords({
-            'chan': (('chan',), xdsb[0].chan.data),
-            'l_beam': (('l_beam',), l_beam),
-            'm_beam': (('m_beam',), m_beam)
-        })
-
-        times = np.array(times)
-        freq_max = np.array(freq_max)
-        freq_min = np.array(freq_min)
-        time_max = np.array(time_max)
-        time_min = np.array(time_min)
-        tout = np.round(np.mean(times), 5) # avoid precision issues
-        xdso = xdso.assign_attrs({
-            'dec': xdsb[0].dec, # always the case?
-            'ra': xdsb[0].ra, # always the case?
-            'time_out': tout,
-            'time_max': time_max.max(),
-            'time_min': time_min.min(),
-            'timeid': 0,
-            'freq_out': nu,
-            'freq_max': freq_max.max(),
-            'freq_min': freq_min.min(),
-        })
-        xds_out.append(xdso)
-    return xds_out
-
-
-def concat_chan(xds, nband_out=1):
-    times = []
-    freqs_in = []
-    freqs_min = []
-    freqs_max = []
-    all_freqs = []
-    for ds in xds:
-        times.append(ds.time_out)
-        freqs_in.append(ds.freq_out)
-        freqs_min.append(ds.freq_min)
-        freqs_max.append(ds.freq_max)
-        all_freqs.append(ds.chan)
-
-    times = np.unique(times)
-    freqs_in = np.unique(freqs_in)
-    freqs_min = np.unique(freqs_min)
-    freqs_max = np.unique(freqs_max)
-    all_freqs = np.unique(np.concatenate(all_freqs))
-
-    nband_in = freqs_in.size
-    ntime = times.size
-
-    if nband_in == nband_out or nband_in == 1: # no need to concatenate
-        return xds
-
-    # currently assuming linearly spaced frequencies
-    freq_bins = np.linspace(freqs_min.min(), freqs_max.max(), nband_out+1)
-    bin_centers = (freq_bins[1:] + freq_bins[0:-1])/2
-
-    xds_out = []
-    for t in range(ntime):
-        time = times[t]
-        for b in range(nband_out):
-            xdst = []
-            flow = freq_bins[b]
-            fhigh = freq_bins[b+1]
-            freqsb = all_freqs[all_freqs >= flow]
-            # exclusive except for the last one
-            if b==nband_out-1:
-                freqsb = freqsb[freqsb <= fhigh]
-            else:
-                freqsb = freqsb[freqsb < fhigh]
-            time_max = []
-            time_min = []
-            for ds in xds:
-                # ds overlaps output if either ds.freq_min or ds.freq_max lies in the bin
-                low_in = ds.freq_min > flow and ds.freq_min < fhigh
-                high_in = ds.freq_max > flow and ds.freq_max < fhigh
-
-                if ds.time_out == time and (low_in or high_in):
-                    xdst.append(ds)
-                    time_max.append(ds.time_max)
-                    time_min.append(ds.time_min)
-
-            nrow = xdst[0].row.size
-            nchan = freqsb.size
-
-            freqs_dask = da.from_array(freqsb, chunks=nchan)
-            blocker = Blocker(sum_overlap, 'rc')
-            blocker.add_input('ufreq', freqs_dask, 'f')
-            blocker.add_input('flow', flow, None)
-            blocker.add_input('fhigh', fhigh, None)
-
-            for i, ds in enumerate(xdst):
-                ds = ds.chunk({'row':-1, 'chan':-1})
-                blocker.add_input(f'vis{i}', ds.VIS.data, 'rc')
-                blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rc')
-                blocker.add_input(f'mask{i}', ds.MASK.data, 'rc')
-                blocker.add_input(f'freq{i}', ds.FREQ.data, 'c')
-
-            blocker.add_output('viso', 'rf', ((nrow,), (nchan,)), xdst[0].VIS.dtype)
-            blocker.add_output('wgto', 'rf', ((nrow,), (nchan,)), xdst[0].WEIGHT.dtype)
-            blocker.add_output('masko', 'rf', ((nrow,), (nchan,)), xdst[0].MASK.dtype)
-
-            out_dict = blocker.get_dask_outputs()
-
-            # get weighted sum of beam
-            beam = sum_beam(xdst)
-            l_beam = xdst[0].l_beam.data
-            m_beam = xdst[0].m_beam.data
-
-            data_vars = {
-                'VIS': (('row', 'chan'), out_dict['viso']),
-                'WEIGHT': (('row', 'chan'), out_dict['wgto']),
-                'MASK': (('row', 'chan'), out_dict['masko']),
-                'FREQ': (('chan',), freqs_dask),
-                'UVW': (('row', 'three'), xdst[0].UVW.data), # should be the same across data sets
-                'BEAM': (('l_beam', 'm_beam'), beam)
-            }
-
-            coords = {
-                'chan': (('chan',), freqsb),
-                'l_beam': (('l_beam',), l_beam),
-                'm_beam': (('m_beam',), m_beam)
-            }
-
-            fout = np.round(bin_centers[b], 5) # avoid precision issues
-            time_max = np.array(time_max)
-            time_min = np.array(time_min)
-            attrs = {
-                'freq_out': fout,
-                'freq_max': fhigh,
-                'freq_min': flow,
-                'bandid': b,
-                'dec': xdst[0].dec,
-                'ra': xdst[0].ra,
-                'time_out': time,
-                'time_max': time_max.max(),
-                'time_min': time_min.min()
-            }
-
-            xdso = xr.Dataset(data_vars=data_vars,
-                              coords=coords,
-                              attrs=attrs)
-
-            xds_out.append(xdso)
-    return xds_out
-
-
-def sum_beam(xds):
-    '''
-    Compute the weighted sum of the beams contained in xds
-    weighting by the sum of the weights in each ds
-    '''
-    nx, ny = xds[0].BEAM.shape
-    btype = xds[0].BEAM.dtype
-    blocker = Blocker(_sum_beam, 'xy')
-    blocker.add_input('nx', nx, None)
-    blocker.add_input('ny', ny, None)
-    blocker.add_input('btype', btype, None)
-    for i, ds in enumerate(xds):
-        blocker.add_input(f'beam{i}', ds.BEAM.data, 'xy')
-        blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rf')
-
-    blocker.add_output('beam', 'xy', ((nx,),(ny,)), btype)
-    out_dict = blocker.get_dask_outputs()
-    return out_dict['beam']
-
-def _sum_beam(nx, ny, btype, **kwargs):
-    beam = np.zeros((nx, ny), dtype=btype)
-    # need to separate the different variables in kwargs
-    # i.e. beam, wgt -> nvars=2
-    nitems = len(kwargs)//2
-    wsum = 0.0
-    for i in range(nitems):
-        wgti = kwargs[f'wgt{i}']
-        wsumi = wgti.sum()
-        beam += wsumi * kwargs[f'beam{i}']
-        wsum += wsumi
-
-    if wsum:
-        beam /= wsum
-
-    # blocker expects dict as output
-    out_dict = {}
-    out_dict['beam'] = beam
-
-    return out_dict
-
-def sum_overlap(ufreq, flow, fhigh, **kwargs):
-    # need to separate the different variables in kwargs
-    # i.e. vis, wgt, mask, freq -> nvars=4
-    nitems = len(kwargs)//4
-
-    # output grids
-    nchan = ufreq.size
-    nrow = kwargs['vis0'].shape[0]
-    viso = np.zeros((nrow, nchan), dtype=kwargs['vis0'].dtype)
-    wgto = np.zeros((nrow, nchan), dtype=kwargs['wgt0'].dtype)
-    masko = np.zeros((nrow, nchan), dtype=kwargs['mask0'].dtype)
-
-    # weighted sum at overlap
-    for i in range(nitems):
-        vis = kwargs[f'vis{i}']
-        wgt = kwargs[f'wgt{i}']
-        mask = kwargs[f'mask{i}']
-        nu = kwargs[f'freq{i}']
-        _, idx0, idx1 = np.intersect1d(nu, ufreq, assume_unique=True, return_indices=True)
-        try:
-            viso[:, idx1] += vis[:, idx0] * wgt[:, idx0] * mask[:, idx0]
-            wgto[:, idx1] += wgt[:, idx0] * mask[:, idx0]
-            masko[:, idx1] += mask[:, idx0]
-        except Exception as e:
-            print(flow, fhigh, ufreq, nu)
-            raise e
-
-    # unmasked where at least one data point is unflagged
-    masko = np.where(masko > 0, True, False)
-    # TODO - why does this get trigerred?
-    # if (wgto[masko]==0).any():
-    #     print(np.where(wgto[masko]==0))
-    #     raise ValueError("Weights are zero at unflagged location")
-    viso[masko] = viso[masko]/wgto[masko]
-
-    # blocker expects a dictionary as output
-    out_dict = {}
-    out_dict['viso'] = viso
-    out_dict['wgto'] = wgto
-    out_dict['masko'] = masko.astype(np.uint8)
-
-    return out_dict
-
-
 def l1reweight_func(model,
                     psiH=None,
                     outvar=None,
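Aside: the block deleted above, including the Blocker-based helpers concat_chan, sum_beam and sum_overlap, built its dask graphs with QuartiCal's Blocker, which is why removing the dependency takes the whole block with it. As a rough sketch of one Blocker-free alternative for the weighted beam sum, using plain dask.delayed (hypothetical, not code from this commit; _sum_beam_np and sum_beam_delayed are invented names):

    import numpy as np
    import dask
    import dask.array as da

    @dask.delayed
    def _sum_beam_np(beams, wgts):
        # weight each beam by the sum of its weights, then normalise
        beam = np.zeros_like(beams[0])
        wsum = 0.0
        for b, w in zip(beams, wgts):
            wsumi = w.sum()
            beam += wsumi * b
            wsum += wsumi
        return beam / wsum if wsum else beam

    def sum_beam_delayed(xds):
        # xds: list of xarray Datasets with BEAM and WEIGHT data variables
        nx, ny = xds[0].BEAM.shape
        return da.from_delayed(
            _sum_beam_np([ds.BEAM.data for ds in xds],
                         [ds.WEIGHT.data for ds in xds]),
            shape=(nx, ny), dtype=xds[0].BEAM.dtype)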
2 changes: 0 additions & 2 deletions pfb/utils/stokes2vis.py
@@ -6,12 +6,10 @@
 from distributed import worker_client
 import dask.array as da
 from xarray import Dataset
-# from quartical.utils.numba import coerce_literal
 from operator import getitem
 from pfb.utils.beam import interp_beam
 from pfb.utils.misc import weight_from_sigma, combine_columns
 import dask
-from quartical.utils.dask import Blocker
 from pfb.utils.stokes import stokes_funcs
 from pfb.utils.weighting import weight_data
 from uuid import uuid4
1 change: 0 additions & 1 deletion pfb/utils/weighting.py
@@ -5,7 +5,6 @@
 import dask.array as da
 from ducc0.fft import c2c
 from africanus.constants import c as lightspeed
-from quartical.utils.dask import Blocker
 from pfb.utils.misc import JIT_OPTIONS
 from pfb.utils.stokes import stokes_funcs
 from pfb.utils.naming import xds_from_list
[diffs for the remaining two of the nine changed files did not load]
