Skip to content

Commit

Permalink
Merge pull request #25 from arjunsavel/test_pca_opt
Browse files Browse the repository at this point in the history
multiprocessing!
  • Loading branch information
arjunsavel authored Feb 10, 2024
2 parents 49445dd + eb80c23 commit 00727d0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 14 deletions.
48 changes: 36 additions & 12 deletions src/cortecs/fit/fit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
The high-level API for fitting. Requires the Opac object.
"""

import warnings
from functools import partial
from multiprocessing import Pool
Expand Down Expand Up @@ -138,8 +137,15 @@ def fit_parallel(self):
:return:
"""
with warnings.catch_warnings():
num_processes = 3
func = partial(self.func_fit, self.P, self.T, **self.fitter_kwargs)
num_processes = 1

func = partial(
self.fit_func,
P=self.P,
T=self.T,
prep_res=self.prep_res,
**self.fitter_kwargs
)

self.pbar = tqdm(
total=len(self.wl),
Expand All @@ -150,18 +156,36 @@ def fit_parallel(self):
)

# these two lines are where the bulk of the multiprocessing happens
p = Pool(num_processes)

for i in range(len(self.wl)):
p.apply_async(
func,
args=(self.Opac.cross_section[:, :, i],),
callback=self.update_pbar,
pool = Pool(num_processes)

# actually loop over the wavelengths using the pool. we need to
# reformat the cross_section into a list of 2D arrays
cross_section_reformatted = [
self.opac.cross_section[:, :, i] for i in range(len(self.wl))
]

# we technically want sorted results, but apply_async is the only way to get the progress bar to work!
async_result = []
for i, item in enumerate(cross_section_reformatted):
async_result.append(
[i, pool.apply_async(func, args=(item,), callback=self.update_pbar)]
)
p.close()
p.join()

# Close the pool
pool.close()
pool.join()

# Close the progress bar
self.pbar.close()

# sort the results based on the index
sorted_results = [None] * len(async_result)
for item in async_result:
i, res = item
sorted_results[i] = res.get()

self.fitter_results = [self.prep_res, sorted_results]

return

def save(self, savename):
Expand Down
4 changes: 2 additions & 2 deletions src/cortecs/fit/fit_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def do_pca(cube, nc=3):
return xMat, standardized_cube, s, vh, u


def fit_pca(cross_section, P, T, xMat, fit_axis="pressure", **kwargs):
def fit_pca(cross_section, P, T, prep_res, fit_axis="pressure", **kwargs):
"""
Fits the PCA to the opacity data.
Expand All @@ -120,7 +120,7 @@ def fit_pca(cross_section, P, T, xMat, fit_axis="pressure", **kwargs):
:beta: (nc x pixels) PCA coefficients
"""
# print("shapes for everything:", cross_section.shape, P.shape, T.shape, xMat.shape)

xMat = prep_res
cross_section = move_cross_section_axis(cross_section, fit_axis)
beta = fit_mlr(cross_section, xMat)
return beta
Expand Down
36 changes: 36 additions & 0 deletions src/cortecs/tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,39 @@ def test_nan_pca_cube_errors(self):
do_pca,
bad_cube,
)

def test_fit_parallel_zero_arg(self):
    """
    Compare element 0 of ``fitter_results`` from a parallel PCA fit
    against the same element from a serial PCA fit.
    :return:
    """
    # run the multiprocessing code path first
    parallel_fit = Fitter(self.opac, method="pca")
    parallel_fit.fit(parallel=True)

    # then the serial code path, which serves as the reference
    serial_fit = Fitter(self.opac, method="pca")
    serial_fit.fit(parallel=False)

    # both paths should agree to assert_almost_equal's default precision
    from_parallel = parallel_fit.fitter_results[0]
    from_serial = serial_fit.fitter_results[0]
    np.testing.assert_almost_equal(from_parallel, from_serial)

def test_fit_parallel_first_arg(self):
    """
    Compare element 1 of ``fitter_results`` from a parallel PCA fit
    against the same element from a serial PCA fit.
    :return:
    """
    # run the multiprocessing code path first
    parallel_fit = Fitter(self.opac, method="pca")
    parallel_fit.fit(parallel=True)

    # then the serial code path, which serves as the reference
    serial_fit = Fitter(self.opac, method="pca")
    serial_fit.fit(parallel=False)

    # both paths should agree to assert_almost_equal's default precision
    from_parallel = parallel_fit.fitter_results[1]
    from_serial = serial_fit.fitter_results[1]
    np.testing.assert_almost_equal(from_parallel, from_serial)

0 comments on commit 00727d0

Please sign in to comment.