From bd68b5c213b28669bba769fb86f94936180b6ffc Mon Sep 17 00:00:00 2001 From: landmanbester Date: Wed, 4 Sep 2024 14:16:12 +0200 Subject: [PATCH] use built in numba.set_num_threads to set the number of threads to use instead of manual version --- pfb/operators/gridder.py | 42 ++++++++++++++++++++-------------------- pfb/operators/psi.py | 9 ++++----- pfb/utils/misc.py | 13 +------------ pfb/workers/grid.py | 20 ++++++------------- pfb/workers/spotless.py | 20 ++++++------------- 5 files changed, 38 insertions(+), 66 deletions(-) diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py index b5b905ef..52fd26d5 100644 --- a/pfb/operators/gridder.py +++ b/pfb/operators/gridder.py @@ -17,7 +17,7 @@ from pfb.utils.weighting import counts_to_weights, _compute_counts from pfb.utils.beam import eval_beam from pfb.utils.naming import xds_from_list -from pfb.utils.misc import fitcleanbeam, numba_threads +from pfb.utils.misc import fitcleanbeam iFs = np.fft.ifftshift Fs = np.fft.fftshift @@ -431,26 +431,26 @@ def image_data_products(dsl, # we usually want to re-evaluate this since the robustness may change if robustness is not None: - with numba_threads(nthreads): - counts = _compute_counts(uvw, - freq, - mask, - wgt, - nx, ny, - cellx, celly, - uvw.dtype, - ngrid=np.minimum(nthreads, 8), # limit number of grids - usign=1.0 if flip_u else -1.0, - vsign=1.0 if flip_v else -1.0) - imwgt = counts_to_weights( - counts, - uvw, - freq, - nx, ny, - cellx, celly, - robustness, - usign=1.0 if flip_u else -1.0, - vsign=1.0 if flip_v else -1.0) + numba.set_num_threads(nthreads) + counts = _compute_counts(uvw, + freq, + mask, + wgt, + nx, ny, + cellx, celly, + uvw.dtype, + ngrid=np.minimum(nthreads, 8), # limit number of grids + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) + imwgt = counts_to_weights( + counts, + uvw, + freq, + nx, ny, + cellx, celly, + robustness, + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) if wgt is not None: wgt *= imwgt else: diff --git a/pfb/operators/psi.py b/pfb/operators/psi.py index 2661850f..e778bf24 100644 --- a/pfb/operators/psi.py +++ b/pfb/operators/psi.py @@ -8,7 +8,6 @@ from scipy.datasets import ascent from pfb.wavelets import coeff_size, signal_size, dwt2d, idwt2d, copyT from time import time -from pfb.utils.misc import numba_threads @numba.njit @@ -261,14 +260,14 @@ def hdot(self, alpha, xo): def psi_dot_impl(x, alphao, psib, b, nthreads=1): - with numba_threads(nthreads): - psib.dot(x, alphao) + numba.set_num_threads(nthreads) + psib.dot(x, alphao) return b def psi_hdot_impl(alpha, xo, psib, b, nthreads=1): - with numba_threads(nthreads): - psib.hdot(alpha, xo) + numba.set_num_threads(nthreads) + psib.hdot(alpha, xo) return b diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 040349c8..688d50c4 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -1,5 +1,5 @@ import sys -from contextlib import contextmanager, nullcontext +from contextlib import nullcontext import numpy as np import numexpr as ne import numba @@ -52,17 +52,6 @@ def interaction(self, *args, **kwargs): sys.stdin = _stdin -@contextmanager -def numba_threads(n): - """Context manager for controlling Numba's number of threads.""" - old_value = numba.config.NUMBA_NUM_THREADS - numba.config.NUMBA_NUM_THREADS = n - try: - yield - finally: - numba.config.NUMBA_NUM_THREADS = old_value - - def compute_context(scheduler, output_filename, boring=True): if scheduler == "distributed": return performance_report(filename=output_filename + "_dask_report.html") diff --git a/pfb/workers/grid.py b/pfb/workers/grid.py index 4944a196..aba86369 100644 --- a/pfb/workers/grid.py +++ b/pfb/workers/grid.py @@ -171,7 +171,7 @@ def grid(**kw): except Exception as e: raise e -def _grid(xdsi=None, **kw): +def _grid(**kw): opts = OmegaConf.create(kw) OmegaConf.set_struct(opts, True) @@ -204,20 +204,12 @@ def _grid(xdsi=None, **kw): # xds contains vis products, no imaging weights applied xds_name = f'{basename}.xds' if opts.xds is None else opts.xds xds_store = DaskMSStore(xds_name) - if xdsi is not None: - xds = [] - for ds in xdsi: - xds.append(ds.chunk({'row':-1, - 'chan': -1, - 'l_beam': -1, - 'm_beam': -1})) - else: - try: - assert xds_store.exists() - except Exception as e: - raise ValueError(f"There must be a dataset at {xds_store.url}") + try: + assert xds_store.exists() + except Exception as e: + raise ValueError(f"There must be a dataset at {xds_store.url}") - xds = xds_from_url(xds_name) + xds = xds_from_url(xds_name) times_in = [] freqs_in = [] diff --git a/pfb/workers/spotless.py b/pfb/workers/spotless.py index 0a96af3b..114f62f3 100644 --- a/pfb/workers/spotless.py +++ b/pfb/workers/spotless.py @@ -125,7 +125,7 @@ def spotless(**kw): print(f"All done after {time.time() - ti}s.", file=log) -def _spotless(xdsi=None, **kw): +def _spotless(**kw): ''' Distributed spotless algorithm. @@ -185,20 +185,12 @@ def _spotless(xdsi=None, **kw): # xds contains vis products, no imaging weights applied xds_name = f'{basename}.xds' if opts.xds is None else opts.xds xds_store = DaskMSStore(xds_name) - if xdsi is not None: - xds = [] - for ds in xdsi: - xds.append(ds.chunk({'row':-1, - 'chan': -1, - 'l_beam': -1, - 'm_beam': -1})) - else: - try: - assert xds_store.exists() - except Exception as e: - raise ValueError(f"There must be a dataset at {xds_store.url}") + try: + assert xds_store.exists() + except Exception as e: + raise ValueError(f"There must be a dataset at {xds_store.url}") - xds = xds_from_url(xds_name) + xds = xds_from_url(xds_name) # TODO - how to glob with protocol in tact? xds_list = xds_store.fs.glob(f'{xds_store.url}/*')