Skip to content

Commit

Permalink
unpack datasets sum_beam and sum_overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Nov 15, 2023
1 parent 76ea4d2 commit 2988084
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 46 deletions.
16 changes: 6 additions & 10 deletions pfb/utils/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,12 @@ def _eval_beam(beam_image, l_in, m_in, l_out, m_out):


def eval_beam(beam_image, l_in, m_in, l_out, m_out):
if lin.ndim == 2:
lout_dims = 'xy'
mout_dims = 'xy'
else:
lout_dims = 'x'
mout_dims = 'y'
nxo, nyo = l_out.shape
return da.blockwise(_eval_beam, 'xy',
beam_image, 'xy',
l_in, 'x',
m_in, 'y',
l_out, lout_dims,
m_out, mout_dims,
l_in, None,
m_in, None,
l_out, None,
m_out, None,
adjust_chunks={'x': nxo, 'y': nyo},
dtype=float)
159 changes: 125 additions & 34 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,6 @@ def _rephase_vis(vis, uvw, radec_in, radec_out):

# TODO - should allow coarsening to values other than 1
def concat_row(xds):
# TODO - how to compute average beam before we have access to grid?
times_in = []
freqs = []
for ds in xds:
Expand All @@ -770,44 +769,63 @@ def concat_row(xds):
for b in range(nband):
xdsb = []
times = []
freq_max = []
freq_min = []
time_max = []
time_min = []
nu = freqs[b]
for ds in xds:
if ds.freq_out == nu:
xdsb.append(ds)
times.append(ds.time_out)
freq_max.append(ds.freq_max)
freq_min.append(ds.freq_min)
time_max.append(ds.time_max)
time_min.append(ds.time_min)

wgt = [ds.WEIGHT for ds in xdsb]
vis = [ds.VIS for ds in xdsb]
mask = [ds.MASK for ds in xdsb]
uvw = [ds.UVW for ds in xdsb]

# get weighted sum of beams
beam = sum_beam(xdsb)
l_beam = xdsb[0].l_beam.data
m_beam = xdsb[0].m_beam.data

wgto = xr.concat(wgt, dim='row')
viso = xr.concat(vis, dim='row')
masko = xr.concat(mask, dim='row')
uvwo = xr.concat(uvw, dim='row')

xdso = xr.merge((wgto, viso, masko, uvwo))
xdso['BEAM'] = xdsb[0].BEAM # we need the grid to do this properly
xdso = xdso.assign({'BEAM': (('l_beam', 'm_beam'), beam)})
xdso['FREQ'] = xdsb[0].FREQ # is this always going to be the case?

xdso = xdso.chunk({'row':-1})
xdso = xdso.chunk({'row':-1, 'l_beam':-1, 'm_beam':-1})

xdso = xdso.assign_coords({
'chan': (('chan',), xdsb[0].chan.data)
'chan': (('chan',), xdsb[0].chan.data),
'l_beam': (('l_beam',), l_beam),
'm_beam': (('m_beam',), m_beam)
})

times = np.array(times)
freq_max = np.array(freq_max)
freq_min = np.array(freq_min)
time_max = np.array(time_max)
time_min = np.array(time_min)
tout = np.round(np.mean(times), 5) # avoid precision issues
xdso = xdso.assign_attrs({
'dec': xdsb[0].dec, # always the case?
'ra': xdsb[0].ra, # always the case?
'time_out': tout,
'time_max': times.max(),
'time_min': times.min(),
'time_max': time_max.max(),
'time_min': time_min.min(),
'timeid': 0,
'freq_out': nu,
'freq_max': xdsb[0].freq_max,
'freq_min': xdsb[0].freq_min,
'freq_max': freq_max.max(),
'freq_min': freq_min.min(),
})
xds_out.append(xdso)
return xds_out
Expand Down Expand Up @@ -855,86 +873,159 @@ def concat_chan(xds, nband_out=1):
freqsb = freqsb[freqsb <= fhigh]
else:
freqsb = freqsb[freqsb < fhigh]
time_max = []
time_min = []
for ds in xds:
# ds overlaps output if either ds.freq_min or ds.freq_max lies in the bin
low_in = ds.freq_min > flow and ds.freq_min < fhigh
high_in = ds.freq_max > flow and ds.freq_max < fhigh

if ds.time_out == time and (low_in or high_in):
xdst.append(ds)
time_max.append(ds.time_max)
time_min.append(ds.time_min)


# LB - we should be able to avoid this stack operation by using Jon's *() magic
wgt = da.stack([ds.WEIGHT.data for ds in xdst]).rechunk(-1, -1) # verify chunking over freq axis
vis = da.stack([ds.VIS.data for ds in xdst]).rechunk(-1, -1)
mask = da.stack([ds.MASK.data for ds in xdst]).rechunk(-1, -1)
freq = da.stack([ds.FREQ.data for ds in xdst]).rechunk(-1)
# wgt = da.stack([ds.WEIGHT.data for ds in xdst]).rechunk(-1, -1) # verify chunking over freq axis
# vis = da.stack([ds.VIS.data for ds in xdst]).rechunk(-1, -1)
# mask = da.stack([ds.MASK.data for ds in xdst]).rechunk(-1, -1)
# freq = da.stack([ds.FREQ.data for ds in xdst]).rechunk(-1)

nrow = xdst[0].row.size
nchan = freqsb.size

freqs_dask = da.from_array(freqsb, chunks=nchan)
blocker = Blocker(sum_overlap, 's')
blocker.add_input('vis', vis, 'src')
blocker.add_input('wgt', wgt, 'src')
blocker.add_input('mask', mask, 'src')
blocker.add_input('freq', freq, 'sc')
blocker = Blocker(sum_overlap, 'rc')
blocker.add_input('ufreq', freqs_dask, 'f')
blocker.add_input('flow', flow, None)
blocker.add_input('fhigh', fhigh, None)
blocker.add_output('viso', 'rf', ((nrow,), (nchan,)), vis.dtype)
blocker.add_output('wgto', 'rf', ((nrow,), (nchan,)), wgt.dtype)
blocker.add_output('masko', 'rf', ((nrow,), (nchan,)), mask.dtype)

for i, ds in enumerate(xdst):
ds = ds.chunk({'row':-1, 'chan':-1})
blocker.add_input(f'vis{i}', ds.VIS.data, 'rc')
blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rc')
blocker.add_input(f'mask{i}', ds.MASK.data, 'rc')
blocker.add_input(f'freq{i}', ds.FREQ.data, 'c')

blocker.add_output('viso', 'rf', ((nrow,), (nchan,)), xdst[0].VIS.dtype)
blocker.add_output('wgto', 'rf', ((nrow,), (nchan,)), xdst[0].WEIGHT.dtype)
blocker.add_output('masko', 'rf', ((nrow,), (nchan,)), xdst[0].MASK.dtype)

out_dict = blocker.get_dask_outputs()

# get weighted sum of beam
beam = sum_beam(xdst)
l_beam = xdst[0].l_beam.data
m_beam = xdst[0].m_beam.data

data_vars = {
'VIS': (('row', 'chan'), out_dict['viso']),
'WEIGHT': (('row', 'chan'), out_dict['wgto']),
'MASK': (('row', 'chan'), out_dict['masko']),
'FREQ': (('chan',), freqs_dask),
'UVW': (('row', 'three'), xdst[0].UVW.data), # should be the same across data sets
'BEAM': (('scalar',), xdst[0].BEAM.data) # need to pass in the grid to do this properly
'BEAM': (('l_beam', 'm_beam'), beam)
}

coords = {
'chan': (('chan',), freqsb)
'chan': (('chan',), freqsb),
'l_beam': (('l_beam',), l_beam),
'm_beam': (('m_beam',), m_beam)
}

fout = np.round(bin_centers[b], 5) # avoid precision issues
time_max = np.array(time_max)
time_min = np.array(time_min)
attrs = {
'freq_out': fout,
'freq_max': fhigh,
'freq_min': flow,
'bandid': b,
'dec': xdst[0].dec,
'ra': xdst[0].ra,
'time_out': time
'time_out': time,
'time_max': time_max.max(),
'time_min': time_min.min()
}

xdso = xr.Dataset(data_vars=data_vars,
coords=coords,
attrs=attrs)

xds_out.append(xdso)
return xds_out


def sum_overlap(vis, wgt, mask, freq, ufreq, flow, fhigh):
nds = vis.shape[0]
def sum_beam(xds):
'''
Compute the weighted sum of the beams contained in xds
weighting by the sum of the weights in each ds
'''
nx, ny = xds[0].BEAM.shape
btype = xds[0].BEAM.dtype
blocker = Blocker(_sum_beam, 'xy')
blocker.add_input('nx', nx, None)
blocker.add_input('ny', ny, None)
blocker.add_input('btype', btype, None)
for i, ds in enumerate(xds):
blocker.add_input(f'beam{i}', ds.BEAM.data, 'xy')
blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rf')

blocker.add_output('beam', 'xy', ((nx,),(ny,)), btype)
out_dict = blocker.get_dask_outputs()
return out_dict['beam']

def _sum_beam(nx, ny, btype, **kwargs):
beam = np.zeros((nx, ny), dtype=btype)
# need to separate the different variables in kwargs
# i.e. beam, wgt -> nvars=2
nitems = len(kwargs)//2
wsum = 0.0
for i in range(nitems):
wgti = kwargs[f'wgt{i}']
wsumi = wgti.sum()
beam += wsumi * kwargs[f'beam{i}']
wsum += wsumi

if wsum:
beam /= wsum

# blocker expects dict as output
out_dict = {}
out_dict['beam'] = beam

return out_dict

def sum_overlap(ufreq, flow, fhigh, **kwargs):
# need to separate the different variables in kwargs
# i.e. vis, wgt, mask, freq -> nvars=4
nitems = len(kwargs)//4

# output grids
nchan = ufreq.size
nrow = vis.shape[1]
viso = np.zeros((nrow, nchan), dtype=vis.dtype)
wgto = np.zeros((nrow, nchan), dtype=wgt.dtype)
masko = np.zeros((nrow, nchan), dtype=mask.dtype)
nrow = kwargs['vis0'].shape[0]
viso = np.zeros((nrow, nchan), dtype=kwargs['vis0'].dtype)
wgto = np.zeros((nrow, nchan), dtype=kwargs['wgt0'].dtype)
masko = np.zeros((nrow, nchan), dtype=kwargs['mask0'].dtype)

# weighted sum at overlap
for i in range(nds):
nu = freq[i]
for i in range(nitems):
vis = kwargs[f'vis{i}']
wgt = kwargs[f'wgt{i}']
mask = kwargs[f'mask{i}']
nu = kwargs[f'freq{i}']
_, idx0, idx1 = np.intersect1d(nu, ufreq, assume_unique=True, return_indices=True)
viso[:, idx1] += vis[i][:, idx0] * wgt[i][:, idx0] * mask[i][:, idx0]
wgto[:, idx1] += wgt[i][:, idx0] * mask[i][:, idx0]
masko[:, idx1] += mask[i][:, idx0]
viso[:, idx1] += vis[:, idx0] * wgt[:, idx0] * mask[:, idx0]
wgto[:, idx1] += wgt[:, idx0] * mask[:, idx0]
masko[:, idx1] += mask[:, idx0]

# unmasked where at least one data point is unflagged
masko = np.where(masko > 0, True, False)
# TODO - why does this get trigerred?
# if (wgto[masko]==0).any():
# print(np.where(wgto[masko]==0))
# raise ValueError("Weights are zero at unflagged location")
viso[masko] = viso[masko]/wgto[masko]

# blocker expects a dictionary as output
Expand Down
4 changes: 3 additions & 1 deletion pfb/workers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def _grid(xdsi=None, **kw):
l = (-(nx//2) + da.arange(nx)) * cell_deg + np.deg2rad(x0)
m = (-(ny//2) + da.arange(ny)) * cell_deg + np.deg2rad(y0)
ll, mm = da.meshgrid(l, m, indexing='ij')
bvals = eval_beam(ds.BEAM.data, ll, mm)
l_beam = ds.l_beam.data
m_beam = ds.m_beam.data
bvals = eval_beam(ds.BEAM.data, l_beam, m_beam, ll, mm)
out_ds = out_ds.assign(**{'BEAM': (('x', 'y'), bvals)})

# get the model
Expand Down
2 changes: 1 addition & 1 deletion pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def init(**kw):
gtlist = gainstore.fs.glob(gt.rstrip('/'))
try:
assert len(gtlist) > 0
gainnames.append(*list(map(gainstore.fs.unstrip_protocol, gt)))
gainnames.append(*list(map(gainstore.fs.unstrip_protocol, gtlist)))
except Exception as e:
raise ValueError(f"No gain table at {gt}")
opts.gain_table = gainnames
Expand Down

0 comments on commit 2988084

Please sign in to comment.