Batch ranks for parquet packets and refactor plotting (#171)
lukeshingles authored Apr 23, 2024
1 parent b6cbfa3 commit 515f56a
Showing 31 changed files with 519 additions and 513 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -32,13 +32,13 @@ repos:
# - id: yamlfmt

- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.3.7
+ rev: v0.4.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.3.7
+ rev: v0.4.1
hooks:
- id: ruff-format

1 change: 1 addition & 0 deletions artistools/commands.py
@@ -34,6 +34,7 @@
"describeinputmodel": ("inputmodel.describeinputmodel", "main"),
"exportmassfractions": ("estimators.exportmassfractions", "main"),
"getpath": ("", "get_path"),
"lc": ("lightcurve.plotlightcurve", "main"),
"listtimesteps": ("", "showtimesteptimes"),
"makeartismodelfromparticlegridmap": ("inputmodel.modelfromhydro", "main"),
"maptogrid": ("inputmodel.maptogrid", "main"),
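The new "lc" entry above extends artistools' lazy command registry: each subcommand name maps to a (module path, function name) pair so a module is imported only when its command actually runs. A minimal sketch of that dispatch pattern (the dispatch helper and its package argument are illustrative assumptions, not artistools API):

import importlib

# name -> (module path relative to the package, callable name)
COMMANDS = {
    "lc": ("lightcurve.plotlightcurve", "main"),
    "getpath": ("", "get_path"),
}


def dispatch(name: str, package: str = "artistools") -> None:
    modulepath, funcname = COMMANDS[name]
    # import the target module only at call time; an empty module path
    # means the callable lives in the package's top-level namespace
    module = importlib.import_module(f"{package}.{modulepath}" if modulepath else package)
    getattr(module, funcname)()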
67 changes: 24 additions & 43 deletions artistools/estimators/estimators.py
@@ -6,9 +6,9 @@

import argparse
import contextlib
- import itertools
import math
import multiprocessing
+ import multiprocessing.pool
import sys
import time
import typing as t
@@ -49,17 +49,16 @@ def get_variableunits(key: str) -> str | None:


def get_variablelongunits(key: str) -> str | None:
- variablelongunits = {
+ return {
"heating_dep/total_dep": "",
"TR": "Temperature [K]",
"Te": "Temperature [K]",
"TJ": "Temperature [K]",
- }
- return variablelongunits.get(key)
+ }.get(key)


def get_varname_formatted(varname: str) -> str:
- replacements = {
+ return {
"nne": r"n$_{\rm e}$",
"lognne": r"Log n$_{\rm e}$",
"rho": r"$\rho$",
@@ -70,8 +69,7 @@ def get_varname_formatted(varname: str) -> str:
"gamma_R_bfest": r"$\Gamma_{\rm phot}$ [s$^{-1}$]",
"heating_dep/total_dep": "Heating fraction",
**{f"vel_{ax}_mid_on_c": f"$v_{{{ax}}}$" for ax in ["x", "y", "z", "r", "rcyl"]},
- }
- return replacements.get(varname, varname)
+ }.get(varname, varname)


def apply_filters(
@@ -225,22 +223,6 @@ def read_estimators_from_file(
)


- def batched(iterable, n): # -> Generator[list, Any, None]:
-     """Batch data into iterators of length n. The last batch may be shorter."""
-     # batched('ABCDEFG', 3) --> ABC DEF G
-     if n < 1:
-         msg = "n must be at least one"
-         raise ValueError(msg)
-     it = iter(iterable)
-     while True:
-         chunk_it = itertools.islice(it, n)
-         try:
-             first_el = next(chunk_it)
-         except StopIteration:
-             return
-         yield list(itertools.chain((first_el,), chunk_it))


def get_rankbatch_parquetfile(
modelpath: Path,
folderpath: Path,
@@ -252,44 +234,42 @@
)

if not parquetfilepath.exists():
print(f"{parquetfilepath.relative_to(modelpath.parent)} does not exist")
print(f" generating {parquetfilepath.relative_to(modelpath.parent)}.")
estfilepaths = []
for mpirank in batch_mpiranks:
# not worth printing an error, because ranks with no cells to update do not produce an estimator file
with contextlib.suppress(FileNotFoundError):
estfilepath = at.firstexisting(f"estimators_{mpirank:04d}.out", folder=folderpath, tryzipped=True)
estfilepaths.append(estfilepath)

print(f" reading {len(estfilepaths)} estimator files from {folderpath.relative_to(Path(folderpath).parent)}")
print(
f" reading {len(estfilepaths)} estimator files from {folderpath.relative_to(Path(folderpath).parent)}...",
end="",
flush=True,
)

time_start = time.perf_counter()

- pldf_group = None
+ pldf_batch = None
if at.get_config()["num_processes"] > 1:
- with multiprocessing.get_context("spawn").Pool(processes=at.get_config()["num_processes"]) as pool:
-     for pldf_file in pool.imap(read_estimators_from_file, estfilepaths):
-         if pldf_group is None:
-             pldf_group = pldf_file
-         else:
-             pldf_group = pl.concat([pldf_group, pldf_file], how="diagonal_relaxed")
+ with multiprocessing.Pool(processes=at.get_config()["num_processes"]) as pool:
+     pldf_batch = pl.concat(pool.imap(read_estimators_from_file, estfilepaths), how="diagonal_relaxed")

pool.close()
pool.join()
pool.terminate()

else:
- for pldf_file in (read_estimators_from_file(estfilepath) for estfilepath in estfilepaths):
-     pldf_group = (
-         pldf_file if pldf_group is None else pl.concat([pldf_group, pldf_file], how="diagonal_relaxed")
-     )
+ pldf_batch = pl.concat(map(read_estimators_from_file, estfilepaths), how="diagonal_relaxed")

print(f" took {time.perf_counter() - time_start:.1f} s")
print(
f"took {time.perf_counter() - time_start:.1f} s. Writing {parquetfilepath.relative_to(modelpath.parent)}..."
)

- assert pldf_group is not None
- print(f" writing {parquetfilepath.relative_to(modelpath.parent)}")
- pldf_group.write_parquet(parquetfilepath, compression="zstd", statistics=True, compression_level=8)
+ assert pldf_batch is not None
+ pldf_batch.write_parquet(parquetfilepath, compression="zstd", statistics=True, compression_level=8)

filesize = parquetfilepath.stat().st_size / 1024 / 1024
print(f"Scanning {parquetfilepath.relative_to(modelpath.parent)} ({filesize:.2f} MiB)")
print(f" scanning {parquetfilepath.relative_to(modelpath.parent)} ({filesize:.2f} MiB)")

return parquetfilepath

@@ -344,7 +324,7 @@ def scan_estimators(
)
mpirank_groups = [
(batchindex, mpiranks)
- for batchindex, mpiranks in enumerate(batched(mpiranklist, 100))
+ for batchindex, mpiranks in enumerate(at.misc.batched(mpiranklist, 100))
if mpiranks_matched.intersection(mpiranks)
]

@@ -355,6 +335,7 @@
for runfolder in runfolders
for batchindex, mpiranks in mpirank_groups
)

assert bool(parquetfiles)

pldflazy = pl.concat([pl.scan_parquet(pfile) for pfile in parquetfiles], how="diagonal_relaxed").unique(
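Taken together, the estimators.py changes cache estimator data as one parquet file per batch of 100 MPI ranks instead of one per rank. A self-contained sketch of that pattern under stated assumptions (the file-name scheme, the read_rank_file reader, and the pool size are illustrative stand-ins, not the artistools implementation):

from itertools import batched  # Python 3.12+; older versions need a backport
from multiprocessing import Pool
from pathlib import Path

import polars as pl


def read_rank_file(path: Path) -> pl.DataFrame:
    # Illustrative reader; the real estimator parser is far more involved.
    return pl.read_csv(path, separator=" ")


def get_rankbatch_parquet(folder: Path, batchindex: int, ranks: list[int]) -> Path:
    parquetpath = folder / f"estimbatch{batchindex:02d}_{ranks[0]:04d}_{ranks[-1]:04d}.parquet"
    if not parquetpath.exists():
        # ranks with no cells to update produce no file, so skip missing ones
        estfiles = [p for r in ranks if (p := folder / f"estimators_{r:04d}.out").exists()]
        with Pool(processes=4) as pool:
            # diagonal_relaxed aligns mismatched columns and dtypes across rank files
            dfbatch = pl.concat(pool.imap(read_rank_file, estfiles), how="diagonal_relaxed")
        dfbatch.write_parquet(parquetpath, compression="zstd", compression_level=8, statistics=True)
    return parquetpath


def scan_estimators(folder: Path, nranks: int, batchsize: int = 100) -> pl.LazyFrame:
    pfiles = [
        get_rankbatch_parquet(folder, i, list(ranks))
        for i, ranks in enumerate(batched(range(nranks), batchsize))
    ]
    return pl.concat([pl.scan_parquet(f) for f in pfiles], how="diagonal_relaxed")

With a batch size of 100, this cuts the number of cached parquet files (and scan_parquet calls) by roughly two orders of magnitude for large runs while keeping lazy, column-pruned reads.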
4 changes: 2 additions & 2 deletions artistools/gsinetwork.py
@@ -585,7 +585,7 @@ def plot_qdot_abund_modelcells(
print(f"Reading Qdot/thermo and abundance data for {len(list_particleids_getabund)} particles")

if at.get_config()["num_processes"] > 1:
- with multiprocessing.get_context("spawn").Pool(processes=at.get_config()["num_processes"]) as pool:
+ with multiprocessing.Pool(processes=at.get_config()["num_processes"]) as pool:
list_particledata_withabund = pool.map(fworkerwithabund, list_particleids_getabund)
pool.close()
pool.join()
@@ -599,7 +599,7 @@
print(f"Reading for Qdot/thermo data (no abundances needed) for {len(list_particleids_noabund)} particles")

if at.get_config()["num_processes"] > 1:
- with multiprocessing.get_context("spawn").Pool(processes=at.get_config()["num_processes"]) as pool:
+ with multiprocessing.Pool(processes=at.get_config()["num_processes"]) as pool:
list_particledata_noabund = pool.map(fworkernoabund, list_particleids_noabund)
pool.close()
pool.join()
5 changes: 3 additions & 2 deletions artistools/inputmodel/describeinputmodel.py
@@ -59,8 +59,9 @@ def main(args: argparse.Namespace | None = None, argsraw: t.Sequence[str] | None
derived_cols=["mass_g", "vel_r_mid", "rho"],
)

- dfmodel = dfmodel.filter(pl.col("rho") > 0.0)
- dfmodel = dfmodel.drop("X_n") # don't confuse neutrons with Nitrogen
+ dfmodel = (
+     dfmodel.filter(pl.col("rho") > 0.0).drop("X_n") # don't confuse neutrons with Nitrogen
+ )

if args.noabund:
dfmodel = dfmodel.drop(cs.starts_with("X_"))
5 changes: 2 additions & 3 deletions artistools/inputmodel/energyinputfiles.py
@@ -181,9 +181,8 @@ def get_rprocess_calculation_files(path_to_rprocess_calculation, interpolate_tra

interpolated_trajectories.to_csv(path_to_rprocess_calculation / "interpolatedQdot.dat", sep=" ", index=False)
print(f"sum etot {sum(trajectory_E_tot)}")
- trajectory_energy = {"id": trajectory_ids, "E_tot": trajectory_E_tot}
- trajectory_energy = pd.DataFrame.from_dict(trajectory_energy)
- trajectory_energy = trajectory_energy.sort_values(by="id")
+ trajectory_energy = pd.DataFrame.from_dict({"id": trajectory_ids, "E_tot": trajectory_E_tot}).sort_values(by="id")

print(trajectory_energy)
trajectory_energy.to_csv(path_to_rprocess_calculation / "trajectoryQ.dat", sep=" ", index=False)

2 changes: 1 addition & 1 deletion artistools/inputmodel/fromcmfgen/rd_cmfgen.py
@@ -90,7 +90,7 @@ def rd_nuc_decay_data(file, quiet=False):
aiso_daughter[i] = np.rint(amu_daughter[i])
edec[i] = float(linearr[5]) * MEV2ERG # convert to ergs
seqnum.append(linearr[6])
- if seqnum[-1] in ["F", "E"]:
+ if seqnum[-1] in ("F", "E"):
nchains = nchains + 1
nlines[i] = int(linearr[7])
if not quiet:
3 changes: 2 additions & 1 deletion artistools/inputmodel/rprocess_from_trajectory.py
@@ -434,10 +434,11 @@ def add_abundancecontributions(
trajworker = partial(get_trajectory_abund_q, t_model_s=t_model_s, traj_root=traj_root, getqdotintegral=True)

if at.get_config()["num_processes"] > 1:
- with multiprocessing.get_context("spawn").Pool(processes=at.get_config()["num_processes"]) as pool:
+ with multiprocessing.Pool(processes=at.get_config()["num_processes"]) as pool:
list_traj_nuc_abund = pool.map(trajworker, particleids)
pool.close()
pool.join()
+ pool.terminate()
else:
list_traj_nuc_abund = [trajworker(particleid) for particleid in particleids]

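This hunk and the matching ones in estimators.py, gsinetwork.py, and linefluxes.py all use the same worker-pool teardown sequence. A standalone sketch (illustrative names, not artistools code) of why close() and join() come before the with-block exit:

import multiprocessing


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    with multiprocessing.Pool(processes=4) as pool:
        results = pool.map(square, range(8))
        pool.close()  # stop accepting new tasks
        pool.join()   # wait for all workers to finish cleanly
        pool.terminate()  # effectively a no-op here; Pool.__exit__ calls it anyway
    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]

Without close() and join(), leaving the with block terminates workers immediately, which is fine after a blocking map() but matters for lazy iterators such as imap().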
9 changes: 4 additions & 5 deletions artistools/inputmodel/rprocess_solar.py
@@ -124,18 +124,17 @@ def undecayed_z(row):
for _, row in dfsolarabund_undecayed.query("radioactive == True").iterrows():
rowdict[f"X_{at.get_elsymbol(int(row.Z))}{int(row.A)}"] = row.massfrac

- modeldata = []
- for mgi, densityrow in dfdensities.iterrows():
-     # print(mgi, densityrow)
-     modeldata.append(
+ modeldata = [
+     (
      {
          "inputcellid": mgi + 1,
          "vel_r_max_kmps": densityrow["vel_r_max_kmps"],
          "logrho": math.log10(densityrow["rho"]),
      }
      | rowdict
      )
-     # print(modeldata)
+     for mgi, densityrow in dfdensities.iterrows()
+ ]

dfmodel = pd.DataFrame(modeldata)
# print(dfmodel)
3 changes: 1 addition & 2 deletions artistools/inputmodel/slice1dfromconein3dmodel.py
@@ -178,8 +178,7 @@ def make_plot(args):
cone = make_cone(args)

cone = cone.loc[cone["rho_model"] > 0.0002] # cut low densities (empty cells?) from plot
- fig = plt.figure()
- ax = fig.gca(projection="3d")
+ ax = plt.figure().gca(projection="3d")

# print(cone['rho_model'])

7 changes: 2 additions & 5 deletions artistools/lightcurve/lightcurve.py
@@ -124,9 +124,6 @@ def get_from_packets(

dfpackets = dfpackets.select(getcols).collect(streaming=True).lazy()

- npkts_selected = dfpackets.select(pl.count("*")).collect().item(0, 0)
- print(f" {npkts_selected:.2e} packets")

lcdata = {}
for dirbin in directionbins:
if directionbins_are_vpkt_observers:
@@ -143,7 +140,7 @@
elif average_over_phi:
assert not average_over_theta
solidanglefactor = ncosthetabins
pldfpackets_dirbin = dfpackets.filter(pl.col("costhetabin") * 10 == dirbin)
pldfpackets_dirbin = dfpackets.filter(pl.col("costhetabin") * nphibins == dirbin)
elif average_over_theta:
solidanglefactor = nphibins
pldfpackets_dirbin = dfpackets.filter(pl.col("phibin") == dirbin)
@@ -158,7 +155,7 @@
sumcols=["e_rf"],
)

- npkts_selected = pldfpackets_dirbin.select(pl.count("*")).collect().item(0, 0)
+ npkts_selected = pldfpackets_dirbin.select(pl.count("e_rf")).collect().item(0, 0)
print(f" dirbin {dirbin} contains {npkts_selected:.2e} packets")

unitfactor = float((u.erg / u.day).to("solLum"))
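The one-line fix above (costhetabin * 10 becomes costhetabin * nphibins) implies that direction bins combine a costheta index and a phi index into one number. A hedged sketch of that indexing, assuming the usual 10 x 10 binning (an assumption, not stated in this diff):

ncosthetabins = 10  # assumed ARTIS default binning
nphibins = 10


def dirbin(costhetabin: int, phibin: int) -> int:
    # assumed combined index: one bin per (costheta, phi) pair
    return costhetabin * nphibins + phibin


# When averaging over phi, each costheta band is represented by the dirbin
# whose phibin is zero, hence the filter costhetabin * nphibins == dirbin.
phi_averaged = [dirbin(c, 0) for c in range(ncosthetabins)]
assert phi_averaged == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

The old hard-coded * 10 was only correct while nphibins happened to be 10; the fix keeps the filter consistent with any binning.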
2 changes: 1 addition & 1 deletion artistools/lightcurve/plotlightcurve.py
@@ -301,7 +301,7 @@ def plot_artis_lightcurve(
print("====> (no series label)")
else:
print(f"====> {modelname}")
print(f" folder: {modelpath.resolve().parts[-1]}")
print(f" modelpath: {modelpath.resolve().parts[-1]}")

if args is not None and args.title:
axis.set_title(modelname)
9 changes: 5 additions & 4 deletions artistools/linefluxes.py
@@ -65,7 +65,7 @@ def get_packets_with_emtype_onefile(
def get_packets_with_emtype(
modelpath: Path | str, emtypecolumn: str, lineindices: t.Sequence[int], maxpacketfiles: int | None = None
):
- packetsfiles = at.packets.get_packetsfilepaths(modelpath, maxpacketfiles=maxpacketfiles)
+ packetsfiles = at.packets.get_packets_text_paths(modelpath, maxpacketfiles=maxpacketfiles)
nprocs_read = len(packetsfiles)
assert nprocs_read > 0

@@ -74,7 +74,7 @@
processfile = partial(get_packets_with_emtype_onefile, emtypecolumn, lineindices)
if at.get_config()["num_processes"] > 1:
print(f"Reading packets files with {at.get_config()['num_processes']} processes")
- with multiprocessing.get_context("spawn").Pool(processes=at.get_config()["num_processes"]) as pool:
+ with multiprocessing.Pool(processes=at.get_config()["num_processes"]) as pool:
arr_dfmatchingpackets = pool.map(processfile, packetsfiles)
pool.close()
pool.join()
@@ -231,8 +231,9 @@ def get_closelines(
lowerlevelindex: int | None = None,
upperlevelindex: int | None = None,
):
- dflinelist = at.get_linelist_dataframe(modelpath)
- dflinelistclosematches = dflinelist.query("atomic_number == @atomic_number and ion_stage == @ion_stage").copy()
+ dflinelistclosematches = (
+     at.get_linelist_dataframe(modelpath).query("atomic_number == @atomic_number and ion_stage == @ion_stage").copy()
+ )
if lambdamin is not None:
dflinelistclosematches = dflinelistclosematches.query("@lambdamin < lambda_angstroms")
if lambdamax is not None:
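The get_closelines change chains pandas DataFrame.query calls in which @-prefixed names resolve to local Python variables rather than to columns. A standalone illustration with made-up data:

import pandas as pd

dflinelist = pd.DataFrame({
    "atomic_number": [26, 26, 28],
    "ion_stage": [2, 3, 2],
    "lambda_angstroms": [4500.0, 5200.0, 7100.0],
})

atomic_number = 26
ion_stage = 2
lambdamin = 4000.0

# @name pulls the value from the enclosing Python scope
matches = dflinelist.query("atomic_number == @atomic_number and ion_stage == @ion_stage").copy()
matches = matches.query("@lambdamin < lambda_angstroms")
print(matches)  # only the 4500-angstrom row survives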