Skip to content

Commit

Permalink
use built in numba.set_num_threads to set the number of threads to us…
Browse files Browse the repository at this point in the history
…e instead of manual version
  • Loading branch information
landmanbester committed Sep 4, 2024
1 parent 300b275 commit bd68b5c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 66 deletions.
42 changes: 21 additions & 21 deletions pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pfb.utils.weighting import counts_to_weights, _compute_counts
from pfb.utils.beam import eval_beam
from pfb.utils.naming import xds_from_list
from pfb.utils.misc import fitcleanbeam, numba_threads
from pfb.utils.misc import fitcleanbeam
iFs = np.fft.ifftshift
Fs = np.fft.fftshift

Expand Down Expand Up @@ -431,26 +431,26 @@ def image_data_products(dsl,

# we usually want to re-evaluate this since the robustness may change
if robustness is not None:
with numba_threads(nthreads):
counts = _compute_counts(uvw,
freq,
mask,
wgt,
nx, ny,
cellx, celly,
uvw.dtype,
ngrid=np.minimum(nthreads, 8), # limit number of grids
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
imwgt = counts_to_weights(
counts,
uvw,
freq,
nx, ny,
cellx, celly,
robustness,
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
numba.set_num_threads(nthreads)
counts = _compute_counts(uvw,
freq,
mask,
wgt,
nx, ny,
cellx, celly,
uvw.dtype,
ngrid=np.minimum(nthreads, 8), # limit number of grids
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
imwgt = counts_to_weights(
counts,
uvw,
freq,
nx, ny,
cellx, celly,
robustness,
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
if wgt is not None:
wgt *= imwgt
else:
Expand Down
9 changes: 4 additions & 5 deletions pfb/operators/psi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from scipy.datasets import ascent
from pfb.wavelets import coeff_size, signal_size, dwt2d, idwt2d, copyT
from time import time
from pfb.utils.misc import numba_threads


@numba.njit
Expand Down Expand Up @@ -261,14 +260,14 @@ def hdot(self, alpha, xo):


def psi_dot_impl(x, alphao, psib, b, nthreads=1):
with numba_threads(nthreads):
psib.dot(x, alphao)
numba.set_num_threads(nthreads)
psib.dot(x, alphao)
return b


def psi_hdot_impl(alpha, xo, psib, b, nthreads=1):
with numba_threads(nthreads):
psib.hdot(alpha, xo)
numba.set_num_threads(nthreads)
psib.hdot(alpha, xo)
return b


Expand Down
13 changes: 1 addition & 12 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
import numpy as np
import numexpr as ne
import numba
Expand Down Expand Up @@ -52,17 +52,6 @@ def interaction(self, *args, **kwargs):
sys.stdin = _stdin


@contextmanager
def numba_threads(n):
"""Context manager for controlling Numba's number of threads."""
old_value = numba.config.NUMBA_NUM_THREADS
numba.config.NUMBA_NUM_THREADS = n
try:
yield
finally:
numba.config.NUMBA_NUM_THREADS = old_value


def compute_context(scheduler, output_filename, boring=True):
if scheduler == "distributed":
return performance_report(filename=output_filename + "_dask_report.html")
Expand Down
20 changes: 6 additions & 14 deletions pfb/workers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def grid(**kw):
except Exception as e:
raise e

def _grid(xdsi=None, **kw):
def _grid(**kw):
opts = OmegaConf.create(kw)
OmegaConf.set_struct(opts, True)

Expand Down Expand Up @@ -204,20 +204,12 @@ def _grid(xdsi=None, **kw):
# xds contains vis products, no imaging weights applied
xds_name = f'{basename}.xds' if opts.xds is None else opts.xds
xds_store = DaskMSStore(xds_name)
if xdsi is not None:
xds = []
for ds in xdsi:
xds.append(ds.chunk({'row':-1,
'chan': -1,
'l_beam': -1,
'm_beam': -1}))
else:
try:
assert xds_store.exists()
except Exception as e:
raise ValueError(f"There must be a dataset at {xds_store.url}")
try:
assert xds_store.exists()
except Exception as e:
raise ValueError(f"There must be a dataset at {xds_store.url}")

xds = xds_from_url(xds_name)
xds = xds_from_url(xds_name)

times_in = []
freqs_in = []
Expand Down
20 changes: 6 additions & 14 deletions pfb/workers/spotless.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def spotless(**kw):
print(f"All done after {time.time() - ti}s.", file=log)


def _spotless(xdsi=None, **kw):
def _spotless(**kw):
'''
Distributed spotless algorithm.
Expand Down Expand Up @@ -185,20 +185,12 @@ def _spotless(xdsi=None, **kw):
# xds contains vis products, no imaging weights applied
xds_name = f'{basename}.xds' if opts.xds is None else opts.xds
xds_store = DaskMSStore(xds_name)
if xdsi is not None:
xds = []
for ds in xdsi:
xds.append(ds.chunk({'row':-1,
'chan': -1,
'l_beam': -1,
'm_beam': -1}))
else:
try:
assert xds_store.exists()
except Exception as e:
raise ValueError(f"There must be a dataset at {xds_store.url}")
try:
assert xds_store.exists()
except Exception as e:
raise ValueError(f"There must be a dataset at {xds_store.url}")

xds = xds_from_url(xds_name)
xds = xds_from_url(xds_name)

# TODO - how to glob with protocol in tact?
xds_list = xds_store.fs.glob(f'{xds_store.url}/*')
Expand Down

0 comments on commit bd68b5c

Please sign in to comment.