diff --git a/src/cortecs/fit/fit.py b/src/cortecs/fit/fit.py index f8b8cdc..62301df 100644 --- a/src/cortecs/fit/fit.py +++ b/src/cortecs/fit/fit.py @@ -1,7 +1,6 @@ """ The high-level API for fitting. Requires the Opac object. """ - import warnings from functools import partial from multiprocessing import Pool @@ -138,8 +137,15 @@ def fit_parallel(self): :return: """ with warnings.catch_warnings(): - num_processes = 3 - func = partial(self.func_fit, self.P, self.T, **self.fitter_kwargs) + num_processes = 1 + + func = partial( + self.fit_func, + P=self.P, + T=self.T, + prep_res=self.prep_res, + **self.fitter_kwargs + ) self.pbar = tqdm( total=len(self.wl), @@ -150,18 +156,36 @@ def fit_parallel(self): ) # these two lines are where the bulk of the multiprocessing happens - p = Pool(num_processes) - - for i in range(len(self.wl)): - p.apply_async( - func, - args=(self.Opac.cross_section[:, :, i],), - callback=self.update_pbar, + pool = Pool(num_processes) + + # actually loop over using pool.map. need to + # reformat the cross_section to be a list of 2D arrays + cross_section_reformatted = [ + self.opac.cross_section[:, :, i] for i in range(len(self.wl)) + ] + + # we technically want sorted results. but apply async is the only way to get the progress bar to work! 
+ async_result = [] + for i, item in enumerate(cross_section_reformatted): + async_result.append( + [i, pool.apply_async(func, args=(item,), callback=self.update_pbar)] ) - p.close() - p.join() + + # Close the pool + pool.close() + pool.join() + + # Close the progress bar self.pbar.close() + # sort the results based on the index + sorted_results = [None] * len(async_result) + for item in async_result: + i, res = item + sorted_results[i] = res.get() + + self.fitter_results = [self.prep_res, sorted_results] + return def save(self, savename): diff --git a/src/cortecs/fit/fit_pca.py b/src/cortecs/fit/fit_pca.py index 4e3076b..93eaaaf 100644 --- a/src/cortecs/fit/fit_pca.py +++ b/src/cortecs/fit/fit_pca.py @@ -104,7 +104,7 @@ def do_pca(cube, nc=3): return xMat, standardized_cube, s, vh, u -def fit_pca(cross_section, P, T, xMat, fit_axis="pressure", **kwargs): +def fit_pca(cross_section, P, T, prep_res, fit_axis="pressure", **kwargs): """ Fits the PCA to the opacity data. @@ -120,7 +120,7 @@ def fit_pca(cross_section, P, T, xMat, fit_axis="pressure", **kwargs): :beta: (nc x pixels) PCA coefficients """ # print("shapes for everything:", cross_section.shape, P.shape, T.shape, xMat.shape) - + xMat = prep_res cross_section = move_cross_section_axis(cross_section, fit_axis) beta = fit_mlr(cross_section, xMat) return beta diff --git a/src/cortecs/tests/test_fit.py b/src/cortecs/tests/test_fit.py index 220b23b..edd52fd 100644 --- a/src/cortecs/tests/test_fit.py +++ b/src/cortecs/tests/test_fit.py @@ -103,3 +103,39 @@ def test_nan_pca_cube_errors(self): do_pca, bad_cube, ) + + def test_fit_parallel_zero_arg(self): + """ + I want to make sure that the parallel fitting function works. + :return: + """ + fitter_parallel = Fitter(self.opac, method="pca") + fitter_parallel.fit(parallel=True) + + # check against serial + fitter_serial = Fitter(self.opac, method="pca") + fitter_serial.fit(parallel=False) + + # are they the same? 
+ + np.testing.assert_almost_equal( + fitter_parallel.fitter_results[0], fitter_serial.fitter_results[0] + ) + + def test_fit_parallel_first_arg(self): + """ + I want to make sure that the parallel fitting function works. + :return: + """ + fitter_parallel = Fitter(self.opac, method="pca") + fitter_parallel.fit(parallel=True) + + # check against serial + fitter_serial = Fitter(self.opac, method="pca") + fitter_serial.fit(parallel=False) + + # are they the same? + + np.testing.assert_almost_equal( + fitter_parallel.fitter_results[1], fitter_serial.fitter_results[1] + )