Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plotspectra: enable gamma spectra plotting and choose x axis unit with -x "hz" or -x "kev" #280

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 88 additions & 49 deletions artistools/spectra/plotspectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,11 @@ def plot_artis_spectrum(
average_over_theta: bool = False,
usedegrees: bool = False,
maxpacketfiles: int | None = None,
xunit: str = "angstroms",
**plotkwargs,
) -> pl.DataFrame | None:
"""Plot an ARTIS output spectrum. The data plotted are also returned as a DataFrame."""
assert xunit in {"angstroms", "kev"}
modelpath = Path(modelpath)
if Path(modelpath).is_file(): # handle e.g. modelpath = 'modelpath/spec.out'
specfilename = Path(modelpath).parts[-1]
Expand Down Expand Up @@ -329,8 +331,8 @@ def plot_artis_spectrum(
if from_packets:
viewinganglespectra = atspectra.get_from_packets(
modelpath,
args.timemin,
args.timemax,
timelowdays=args.timemin,
timehighdays=args.timemax,
lambda_min=supxmin * 0.9,
lambda_max=supxmax * 1.1,
use_time=use_time,
Expand All @@ -341,6 +343,7 @@ def plot_artis_spectrum(
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
directionbins_are_vpkt_observers=args.plotvspecpol is not None,
gamma=args.gamma,
)

elif args.plotvspecpol is not None:
Expand Down Expand Up @@ -369,6 +372,7 @@ def plot_artis_spectrum(
average_over_phi=average_over_phi,
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
gamma=args.gamma,
)

dirbin_definitions = (
Expand Down Expand Up @@ -412,26 +416,23 @@ def plot_artis_spectrum(
if len(directionbins) > 1 or not linelabel_is_custom:
linelabel_withdirbin = f"{linelabel} {dirbin_definitions[dirbin]}"

atspectra.print_integrated_flux(dfspectrum["f_lambda"], dfspectrum["lambda_angstroms"])

if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
f_lambda_scaled=pl.col("f_lambda") / pl.col("f_lambda").max() * scale_to_peak
)

ycolumnname = "f_lambda_scaled"
else:
ycolumnname = "f_lambda"
# atspectra.print_integrated_flux(dfspectrum["f_lambda"], dfspectrum["lambda_angstroms"])

if plotpacketcount:
ycolumnname = "packetcount"
dfspectrum = dfspectrum.with_columns(y=pl.col("packetcount"))
else:
dfspectrum = dfspectrum.with_columns(y=pl.col("f_lambda" if xunit == "angstroms" else "f_nu"))
if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
y_scaled=pl.col("y") / pl.col("y").max() * scale_to_peak
).with_columns(y=pl.col("y_scaled"))

if args.binflux:
new_lambda_angstroms = []
binned_flux = []

wavelengths = dfspectrum["lambda_angstroms"]
fluxes = dfspectrum[ycolumnname]
fluxes = dfspectrum["y"]
nbins = 5

for i in np.arange(0, len(wavelengths - nbins), nbins, dtype=int):
Expand All @@ -442,13 +443,18 @@ def plot_artis_spectrum(
sum_flux = sum(fluxes[j] for j in range(i, i_max))
binned_flux.append(sum_flux / ncontribs)

dfspectrum = pl.DataFrame({"lambda_angstroms": new_lambda_angstroms, ycolumnname: binned_flux})
dfspectrum = pl.DataFrame({"lambda_angstroms": new_lambda_angstroms, "y": binned_flux})

if args.x == "angstroms":
dfspectrum = dfspectrum.with_columns(x=pl.col("lambda_angstroms"))
else:
h = 4.1356677e-15 # Planck's constant [eV s]
c = 2.99792458e18 # speed of light [angstroms/s]
dfspectrum = dfspectrum.with_columns(x=h * c / pl.col("lambda_angstroms") / 1000.0, y=pl.col("f_nu"))
dfspectrum = dfspectrum.sort("x", maintain_order=True)

axis.plot(
dfspectrum["lambda_angstroms"],
dfspectrum[ycolumnname],
label=linelabel_withdirbin if axindex == 0 else None,
**plotkwargs,
dfspectrum["x"], dfspectrum["y"], label=linelabel_withdirbin if axindex == 0 else None, **plotkwargs
)

return dfspectrum[["lambda_angstroms", "f_lambda"]]
Expand Down Expand Up @@ -558,6 +564,7 @@ def make_spectrum_plot(
average_over_phi=args.average_over_phi_angle,
average_over_theta=args.average_over_theta_angle,
usedegrees=args.usedegrees,
xunit=args.x,
**plotkwargs,
)
except FileNotFoundError as e:
Expand Down Expand Up @@ -607,13 +614,15 @@ def make_spectrum_plot(
# zorder=-1,
# )

if args.stokesparam == "I":
if args.stokesparam == "I" and not args.logscaley:
axis.set_ylim(bottom=0.0)
if args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if args.plotpacketcount:
axis.set_ylabel(r"Monte Carlo packets per bin")
elif args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if not args.notitle and args.title:
if args.inset_title:
axis.annotate(
Expand Down Expand Up @@ -1008,33 +1017,46 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
dfalldata: pl.DataFrame | None = pl.DataFrame()

if not args.hideyticklabels:
ylabel = None
if args.x == "angstroms":
ylabel = r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]"
elif args.x.lower() == "kev":
ylabel = r"F$_\nu$ at 1 Mpc [{}erg/s/cm$^2$/Hz]"

assert ylabel is not None
if args.logscaley:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
ylabel = ylabel.replace("{}", "")

if args.multispecplot:
for ax in axes:
ax.set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")

elif args.logscale:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [erg/s/cm$^2$/$\mathrm{{\AA}}$]")
ax.set_ylabel(ylabel)
else:
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")
axes[-1].set_ylabel(ylabel)

for axis in axes:
if args.logscale:
if args.xmin is not None:
axis.set_xlim(left=args.xmin)
if args.xmax is not None:
axis.set_xlim(right=args.xmax)
if args.logscalex:
axis.set_xscale("log")
if args.logscaley:
axis.set_yscale("log")
axis.set_xlim(left=args.xmin, right=args.xmax)

if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if not args.gamma:
if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if densityplotyvars:
make_contrib_plot(axes[:-1], args.specpath[0], densityplotyvars, args)
Expand Down Expand Up @@ -1100,7 +1122,7 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
for index, ax in enumerate(axes):
# ax.xaxis.set_major_formatter(plt.NullFormatter())

if "{" in ax.get_ylabel() and not args.logscale:
if "{" in ax.get_ylabel() and not args.logscaley:
ax.yaxis.set_major_formatter(ExponentLabelFormatter(ax.get_ylabel(), decimalplaces=1))

if args.hidexticklabels:
Expand All @@ -1114,7 +1136,10 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
ax.text(5500, ymax * 0.9, f"{args.timedayslist[index]} days") # multispecplot text

if not args.hidexticklabels:
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
if args.x == "angstroms":
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
else:
axes[-1].set_xlabel(r"Energy $\left[\mathrm{{keV}}\right]$")

if not args.outputfile:
args.outputfile = defaultoutputfile
Expand Down Expand Up @@ -1164,6 +1189,8 @@ def addargs(parser) -> None:

parser.add_argument("-dashes", default=[], nargs="*", help="Dashes property of lines")

parser.add_argument("--gamma", action="store_true", help="Make light curve from gamma rays instead of R-packets")

parser.add_argument("--greyscale", action="store_true", help="Plot in greyscale")

parser.add_argument(
Expand Down Expand Up @@ -1232,12 +1259,14 @@ def addargs(parser) -> None:
"--notimeclamp", action="store_true", help="When plotting from packets, don't clamp to timestep start/end"
)

parser.add_argument("-x", default=None, choices=["angstroms", "kev", "hz"], help="x (horizontal) axis unit")

parser.add_argument(
"-xmin", "-lambdamin", dest="xmin", type=int, default=2500, help="Plot range: minimum wavelength in Angstroms"
"-xmin", "-lambdamin", dest="xmin", type=float, default=None, help="Plot range: minimum x range"
)

parser.add_argument(
"-xmax", "-lambdamax", dest="xmax", type=int, default=19000, help="Plot range: maximum wavelength in Angstroms"
"-xmax", "-lambdamax", dest="xmax", type=float, default=None, help="Plot range: maximum x range"
)

parser.add_argument(
Expand Down Expand Up @@ -1304,7 +1333,9 @@ def addargs(parser) -> None:
"-figscale", type=float, default=1.8, help="Scale factor for plot area. 1.0 is for single-column"
)

parser.add_argument("--logscale", action="store_true", help="Use log scale")
parser.add_argument("--logscalex", action="store_true", help="Use log scale for x values")

parser.add_argument("--logscaley", action="store_true", help="Use log scale for y values")

parser.add_argument("--hidenetspectrum", action="store_true", help="Hide net spectrum")

Expand Down Expand Up @@ -1426,6 +1457,14 @@ def main(args: argparse.Namespace | None = None, argsraw: Sequence[str] | None =
print("WARNING: --average_every_tenth_viewing_angle is deprecated. use --average_over_phi_angle instead")
args.average_over_phi_angle = True

if args.x is None:
args.x = "kev" if args.gamma else "angstroms"

if args.xmin is None and not args.gamma:
args.xmin = 2500
if args.xmax is None and not args.gamma:
args.xmax = 19000

set_mpl_style()

assert (
Expand Down
30 changes: 20 additions & 10 deletions artistools/spectra/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def get_exspec_bins(
mnubins: int | None = None,
nu_min_r: float | None = None,
nu_max_r: float | None = None,
gamma: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get the wavelength bins for the emergent spectrum."""
if modelpath is not None:
dfspec = read_spec(modelpath)
dfspec = read_spec(modelpath, gamma=gamma)
if mnubins is None:
mnubins = dfspec.height

Expand Down Expand Up @@ -160,6 +161,7 @@ def get_from_packets(
fluxfilterfunc: Callable[[npt.NDArray[np.floating] | pl.Series], npt.NDArray[np.floating]] | None = None,
nprocs_read_dfpackets: tuple[int, pl.DataFrame | pl.LazyFrame] | None = None,
directionbins_are_vpkt_observers: bool = False,
gamma: bool = False,
) -> dict[int, pl.DataFrame]:
"""Get a spectrum dataframe using the packets files as input."""
if directionbins is None:
Expand All @@ -175,7 +177,7 @@ def get_from_packets(
lambda_bin_edges = np.arange(lambda_min, lambda_max + delta_lambda, delta_lambda)
lambda_bin_centres = 0.5 * (lambda_bin_edges[:-1] + lambda_bin_edges[1:]) # bin centres
else:
lambda_bin_edges, lambda_bin_centres, delta_lambda = get_exspec_bins(modelpath=modelpath)
lambda_bin_edges, lambda_bin_centres, delta_lambda = get_exspec_bins(modelpath=modelpath, gamma=gamma)
lambda_min = lambda_bin_centres[0]
lambda_max = lambda_bin_centres[-1]

Expand All @@ -189,10 +191,14 @@ def get_from_packets(
nprocs_read = nprocs_read_dfpackets[0]
dfpackets = nprocs_read_dfpackets[1].lazy()
elif directionbins_are_vpkt_observers:
assert not gamma
nprocs_read, dfpackets = atpackets.get_virtual_packets_pl(modelpath, maxpacketfiles=maxpacketfiles)
else:
nprocs_read, dfpackets = atpackets.get_packets_pl(
modelpath, maxpacketfiles=maxpacketfiles, packet_type="TYPE_ESCAPE", escape_type="TYPE_RPKT"
modelpath,
maxpacketfiles=maxpacketfiles,
packet_type="TYPE_ESCAPE",
escape_type="TYPE_GAMMA" if gamma else "TYPE_RPKT",
)

dfpackets = dfpackets.with_columns([
Expand Down Expand Up @@ -335,7 +341,8 @@ def get_from_packets(
"lambda_angstroms",
pl.col(f"f_lambda_dirbin{dirbin}").alias("f_lambda"),
pl.col(f"count_dirbin{dirbin}").alias("packetcount"),
])
(299792458.0 / (pl.col("lambda_angstroms") * 1e-10)).alias("nu"),
]).with_columns(f_nu=(pl.col("f_lambda") * pl.col("lambda_angstroms") / pl.col("nu")))
if nprocs_read_dfpackets is None:
npkts_selected = dfdict[dirbin].get_column("packetcount").sum()
print(f" dirbin {dirbin:2d} plots {npkts_selected:.2e} packets")
Expand All @@ -344,8 +351,8 @@ def get_from_packets(


@lru_cache(maxsize=16)
def read_spec(modelpath: Path) -> pl.DataFrame:
specfilename = firstexisting("spec.out", folder=modelpath, tryzipped=True)
def read_spec(modelpath: Path, gamma: bool = False) -> pl.DataFrame:
specfilename = firstexisting("gamma_spec.out" if gamma else "spec.out", folder=modelpath, tryzipped=True)
print(f"Reading {specfilename}")

return (
Expand Down Expand Up @@ -445,6 +452,7 @@ def get_spectrum(
average_over_theta: bool = False,
average_over_phi: bool = False,
stokesparam: t.Literal["I", "Q", "U"] = "I",
gamma: bool = False,
) -> dict[int, pl.DataFrame]:
"""Return a pandas DataFrame containing an ARTIS emergent spectrum."""
if timestepmax is None or timestepmax < 0:
Expand All @@ -471,7 +479,7 @@ def get_spectrum(
# spherically averaged spectra
if stokesparam == "I":
try:
specdata[-1] = read_spec(modelpath=modelpath)
specdata[-1] = read_spec(modelpath=modelpath, gamma=gamma)

except FileNotFoundError:
specdata[-1] = get_specpol_data(angle=-1, modelpath=modelpath)[stokesparam]
Expand Down Expand Up @@ -500,9 +508,11 @@ def get_spectrum(

arr_nu = specdata[dirbin]["nu"]
arr_lambda = 2.99792458e18 / arr_nu
dfspectrum = pl.DataFrame({"lambda_angstroms": arr_lambda, "f_lambda": arr_f_nu * arr_nu / arr_lambda}).sort(
by="lambda_angstroms"
)
dfspectrum = pl.DataFrame({
"lambda_angstroms": arr_lambda,
"f_lambda": arr_f_nu * arr_nu / arr_lambda,
"f_nu": arr_f_nu,
}).sort(by="lambda_angstroms")

if fluxfilterfunc:
if dirbin == directionbins[0]:
Expand Down
Loading