Plot gamma spec and fnu spec
lukeshingles committed Dec 17, 2024
1 parent 951bf8a commit f01c3c2
Showing 2 changed files with 108 additions and 59 deletions.
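This commit adds gamma-ray spectrum plotting (--gamma), an x-axis unit option (-x), and separate --logscalex/--logscaley flags. A minimal sketch of how the new options might be driven through the module's argparse entry point is shown below; it is not part of the commit, the model directory is a placeholder, and passing it as the positional argument (and main() forwarding argsraw to the parser) are assumptions about the existing script, not something defined here.

    # Sketch only: exercise the new flags via the plotspectra entry point.
    from artistools.spectra import plotspectra

    plotspectra.main(argsraw=[
        "mymodel",      # placeholder ARTIS model directory (positional specpath assumed)
        "--gamma",      # plot escaping gamma-ray packets instead of r-packets
        "-x", "kev",    # x axis in keV (also the default once --gamma is set)
        "--logscaley",  # log-scale y axis (replaces the old --logscale flag)
    ])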
137 changes: 88 additions & 49 deletions artistools/spectra/plotspectra.py
@@ -255,9 +255,11 @@ def plot_artis_spectrum(
average_over_theta: bool = False,
usedegrees: bool = False,
maxpacketfiles: int | None = None,
xunit: str = "angstroms",
**plotkwargs,
) -> pl.DataFrame | None:
"""Plot an ARTIS output spectrum. The data plotted are also returned as a DataFrame."""
assert xunit in {"angstroms", "kev"}
modelpath = Path(modelpath)
if Path(modelpath).is_file(): # handle e.g. modelpath = 'modelpath/spec.out'
specfilename = Path(modelpath).parts[-1]
@@ -329,8 +331,8 @@ def plot_artis_spectrum(
if from_packets:
viewinganglespectra = atspectra.get_from_packets(
modelpath,
args.timemin,
args.timemax,
timelowdays=args.timemin,
timehighdays=args.timemax,
lambda_min=supxmin * 0.9,
lambda_max=supxmax * 1.1,
use_time=use_time,
@@ -341,6 +343,7 @@
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
directionbins_are_vpkt_observers=args.plotvspecpol is not None,
gamma=args.gamma,
)

elif args.plotvspecpol is not None:
@@ -369,6 +372,7 @@ def plot_artis_spectrum(
average_over_phi=average_over_phi,
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
gamma=args.gamma,
)

dirbin_definitions = (
@@ -412,26 +416,23 @@ def plot_artis_spectrum(
if len(directionbins) > 1 or not linelabel_is_custom:
linelabel_withdirbin = f"{linelabel} {dirbin_definitions[dirbin]}"

atspectra.print_integrated_flux(dfspectrum["f_lambda"], dfspectrum["lambda_angstroms"])

if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
f_lambda_scaled=pl.col("f_lambda") / pl.col("f_lambda").max() * scale_to_peak
)

ycolumnname = "f_lambda_scaled"
else:
ycolumnname = "f_lambda"
# atspectra.print_integrated_flux(dfspectrum["f_lambda"], dfspectrum["lambda_angstroms"])

if plotpacketcount:
ycolumnname = "packetcount"
dfspectrum = dfspectrum.with_columns(y=pl.col("packetcount"))
else:
dfspectrum = dfspectrum.with_columns(y=pl.col("f_lambda" if xunit == "angstroms" else "f_nu"))
if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
y_scaled=pl.col("y") / pl.col("y").max() * scale_to_peak
).with_columns(y=pl.col("y_scaled"))

if args.binflux:
new_lambda_angstroms = []
binned_flux = []

wavelengths = dfspectrum["lambda_angstroms"]
fluxes = dfspectrum[ycolumnname]
fluxes = dfspectrum["y"]
nbins = 5

for i in np.arange(0, len(wavelengths - nbins), nbins, dtype=int):
@@ -442,13 +443,18 @@ def plot_artis_spectrum(
sum_flux = sum(fluxes[j] for j in range(i, i_max))
binned_flux.append(sum_flux / ncontribs)

dfspectrum = pl.DataFrame({"lambda_angstroms": new_lambda_angstroms, ycolumnname: binned_flux})
dfspectrum = pl.DataFrame({"lambda_angstroms": new_lambda_angstroms, "y": binned_flux})

if args.x == "angstroms":
dfspectrum = dfspectrum.with_columns(x=pl.col("lambda_angstroms"))
else:
h = 4.1356677e-15 # Planck's constant [eV s]
c = 2.99792458e18 # speed of light [angstroms/s]
dfspectrum = dfspectrum.with_columns(x=h * c / pl.col("lambda_angstroms") / 1000.0, y=pl.col("f_nu"))
dfspectrum = dfspectrum.sort("x", maintain_order=True)

axis.plot(
dfspectrum["lambda_angstroms"],
dfspectrum[ycolumnname],
label=linelabel_withdirbin if axindex == 0 else None,
**plotkwargs,
dfspectrum["x"], dfspectrum["y"], label=linelabel_withdirbin if axindex == 0 else None, **plotkwargs
)

return dfspectrum[["lambda_angstroms", "f_lambda"]]
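When the x unit is not Angstroms, the wavelength grid is converted to photon energy via E = hc/lambda, with h in eV s and c in Angstroms/s, then divided by 1000 to give keV (so E[keV] is roughly 12.398 / lambda[Angstroms]); the frame is re-sorted afterwards because energy decreases as wavelength increases. A standalone sketch of the same conversion, using the constants from the hunk above and an arbitrary test wavelength:

    h = 4.1356677e-15  # Planck's constant [eV s]
    c = 2.99792458e18  # speed of light [angstroms/s]

    def angstroms_to_kev(lambda_angstroms: float) -> float:
        """Convert a wavelength in Angstroms to a photon energy in keV."""
        return h * c / lambda_angstroms / 1000.0

    print(angstroms_to_kev(12.398))  # ~1.0 keV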
@@ -558,6 +564,7 @@ def make_spectrum_plot(
average_over_phi=args.average_over_phi_angle,
average_over_theta=args.average_over_theta_angle,
usedegrees=args.usedegrees,
xunit=args.x,
**plotkwargs,
)
except FileNotFoundError as e:
@@ -607,13 +614,15 @@ def make_spectrum_plot(
# zorder=-1,
# )

if args.stokesparam == "I":
if args.stokesparam == "I" and not args.logscaley:
axis.set_ylim(bottom=0.0)
if args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if args.plotpacketcount:
axis.set_ylabel(r"Monte Carlo packets per bin")
elif args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if not args.notitle and args.title:
if args.inset_title:
axis.annotate(
@@ -1008,33 +1017,46 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
dfalldata: pl.DataFrame | None = pl.DataFrame()

if not args.hideyticklabels:
ylabel = None
if args.x == "angstroms":
ylabel = r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]"
elif args.x.lower() == "kev":
ylabel = r"F$_\nu$ at 1 Mpc [{}erg/s/cm$^2$/Hz]"

assert ylabel is not None
if args.logscaley:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
ylabel = ylabel.replace("{}", "")

if args.multispecplot:
for ax in axes:
ax.set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")

elif args.logscale:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [erg/s/cm$^2$/$\mathrm{{\AA}}$]")
ax.set_ylabel(ylabel)
else:
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")
axes[-1].set_ylabel(ylabel)

for axis in axes:
if args.logscale:
if args.xmin is not None:
axis.set_xlim(left=args.xmin)
if args.xmax is not None:
axis.set_xlim(right=args.xmax)
if args.logscalex:
axis.set_xscale("log")
if args.logscaley:
axis.set_yscale("log")
axis.set_xlim(left=args.xmin, right=args.xmax)

if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if not args.gamma:
if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if densityplotyvars:
make_contrib_plot(axes[:-1], args.specpath[0], densityplotyvars, args)
@@ -1100,7 +1122,7 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
for index, ax in enumerate(axes):
# ax.xaxis.set_major_formatter(plt.NullFormatter())

if "{" in ax.get_ylabel() and not args.logscale:
if "{" in ax.get_ylabel() and not args.logscaley:
ax.yaxis.set_major_formatter(ExponentLabelFormatter(ax.get_ylabel(), decimalplaces=1))

if args.hidexticklabels:
@@ -1114,7 +1136,10 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
ax.text(5500, ymax * 0.9, f"{args.timedayslist[index]} days") # multispecplot text

if not args.hidexticklabels:
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
if args.x == "angstroms":
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
else:
axes[-1].set_xlabel(r"Energy $\left[\mathrm{{keV}}\right]$")

if not args.outputfile:
args.outputfile = defaultoutputfile
@@ -1164,6 +1189,8 @@ def addargs(parser) -> None:

parser.add_argument("-dashes", default=[], nargs="*", help="Dashes property of lines")

parser.add_argument("--gamma", action="store_true", help="Make light curve from gamma rays instead of R-packets")

parser.add_argument("--greyscale", action="store_true", help="Plot in greyscale")

parser.add_argument(
@@ -1232,12 +1259,14 @@ def addargs(parser) -> None:
"--notimeclamp", action="store_true", help="When plotting from packets, don't clamp to timestep start/end"
)

parser.add_argument("-x", default=None, choices=["angstroms", "kev", "hz"], help="x (horizontal) axis unit")

parser.add_argument(
"-xmin", "-lambdamin", dest="xmin", type=int, default=2500, help="Plot range: minimum wavelength in Angstroms"
"-xmin", "-lambdamin", dest="xmin", type=float, default=None, help="Plot range: minimum x range"
)

parser.add_argument(
"-xmax", "-lambdamax", dest="xmax", type=int, default=19000, help="Plot range: maximum wavelength in Angstroms"
"-xmax", "-lambdamax", dest="xmax", type=float, default=None, help="Plot range: maximum x range"
)

parser.add_argument(
@@ -1304,7 +1333,9 @@ def addargs(parser) -> None:
"-figscale", type=float, default=1.8, help="Scale factor for plot area. 1.0 is for single-column"
)

parser.add_argument("--logscale", action="store_true", help="Use log scale")
parser.add_argument("--logscalex", action="store_true", help="Use log scale for x values")

parser.add_argument("--logscaley", action="store_true", help="Use log scale for y values")

parser.add_argument("--hidenetspectrum", action="store_true", help="Hide net spectrum")

@@ -1426,6 +1457,14 @@ def main(args: argparse.Namespace | None = None, argsraw: Sequence[str] | None =
print("WARNING: --average_every_tenth_viewing_angle is deprecated. use --average_over_phi_angle instead")
args.average_over_phi_angle = True

if args.x is None:
args.x = "kev" if args.gamma else "angstroms"

if args.xmin is None and not args.gamma:
args.xmin = 2500
if args.xmax is None and not args.gamma:
args.xmax = 19000

set_mpl_style()

assert (
30 changes: 20 additions & 10 deletions artistools/spectra/spectra.py
@@ -59,10 +59,11 @@ def get_exspec_bins(
mnubins: int | None = None,
nu_min_r: float | None = None,
nu_max_r: float | None = None,
gamma: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get the wavelength bins for the emergent spectrum."""
if modelpath is not None:
dfspec = read_spec(modelpath)
dfspec = read_spec(modelpath, gamma=gamma)
if mnubins is None:
mnubins = dfspec.height

@@ -160,6 +161,7 @@ def get_from_packets(
fluxfilterfunc: Callable[[npt.NDArray[np.floating] | pl.Series], npt.NDArray[np.floating]] | None = None,
nprocs_read_dfpackets: tuple[int, pl.DataFrame | pl.LazyFrame] | None = None,
directionbins_are_vpkt_observers: bool = False,
gamma: bool = False,
) -> dict[int, pl.DataFrame]:
"""Get a spectrum dataframe using the packets files as input."""
if directionbins is None:
@@ -175,7 +177,7 @@
lambda_bin_edges = np.arange(lambda_min, lambda_max + delta_lambda, delta_lambda)
lambda_bin_centres = 0.5 * (lambda_bin_edges[:-1] + lambda_bin_edges[1:]) # bin centres
else:
lambda_bin_edges, lambda_bin_centres, delta_lambda = get_exspec_bins(modelpath=modelpath)
lambda_bin_edges, lambda_bin_centres, delta_lambda = get_exspec_bins(modelpath=modelpath, gamma=gamma)
lambda_min = lambda_bin_centres[0]
lambda_max = lambda_bin_centres[-1]

@@ -189,10 +191,14 @@
nprocs_read = nprocs_read_dfpackets[0]
dfpackets = nprocs_read_dfpackets[1].lazy()
elif directionbins_are_vpkt_observers:
assert not gamma
nprocs_read, dfpackets = atpackets.get_virtual_packets_pl(modelpath, maxpacketfiles=maxpacketfiles)
else:
nprocs_read, dfpackets = atpackets.get_packets_pl(
modelpath, maxpacketfiles=maxpacketfiles, packet_type="TYPE_ESCAPE", escape_type="TYPE_RPKT"
modelpath,
maxpacketfiles=maxpacketfiles,
packet_type="TYPE_ESCAPE",
escape_type="TYPE_GAMMA" if gamma else "TYPE_RPKT",
)

dfpackets = dfpackets.with_columns([
@@ -335,7 +341,8 @@ def get_from_packets(
"lambda_angstroms",
pl.col(f"f_lambda_dirbin{dirbin}").alias("f_lambda"),
pl.col(f"count_dirbin{dirbin}").alias("packetcount"),
])
(299792458.0 / (pl.col("lambda_angstroms") * 1e-10)).alias("nu"),
]).with_columns(f_nu=(pl.col("f_lambda") * pl.col("lambda_angstroms") / pl.col("nu")))
if nprocs_read_dfpackets is None:
npkts_selected = dfdict[dirbin].get_column("packetcount").sum()
print(f" dirbin {dirbin:2d} plots {npkts_selected:.2e} packets")
@@ -344,8 +351,8 @@


@lru_cache(maxsize=16)
def read_spec(modelpath: Path) -> pl.DataFrame:
specfilename = firstexisting("spec.out", folder=modelpath, tryzipped=True)
def read_spec(modelpath: Path, gamma: bool = False) -> pl.DataFrame:
specfilename = firstexisting("gamma_spec.out" if gamma else "spec.out", folder=modelpath, tryzipped=True)
print(f"Reading {specfilename}")

return (
@@ -445,6 +452,7 @@ def get_spectrum(
average_over_theta: bool = False,
average_over_phi: bool = False,
stokesparam: t.Literal["I", "Q", "U"] = "I",
gamma: bool = False,
) -> dict[int, pl.DataFrame]:
"""Return a pandas DataFrame containing an ARTIS emergent spectrum."""
if timestepmax is None or timestepmax < 0:
@@ -471,7 +479,7 @@
# spherically averaged spectra
if stokesparam == "I":
try:
specdata[-1] = read_spec(modelpath=modelpath)
specdata[-1] = read_spec(modelpath=modelpath, gamma=gamma)

except FileNotFoundError:
specdata[-1] = get_specpol_data(angle=-1, modelpath=modelpath)[stokesparam]
@@ -500,9 +508,11 @@

arr_nu = specdata[dirbin]["nu"]
arr_lambda = 2.99792458e18 / arr_nu
dfspectrum = pl.DataFrame({"lambda_angstroms": arr_lambda, "f_lambda": arr_f_nu * arr_nu / arr_lambda}).sort(
by="lambda_angstroms"
)
dfspectrum = pl.DataFrame({
"lambda_angstroms": arr_lambda,
"f_lambda": arr_f_nu * arr_nu / arr_lambda,
"f_nu": arr_f_nu,
}).sort(by="lambda_angstroms")

if fluxfilterfunc:
if dirbin == directionbins[0]:
