refactored datasets code to create CSV/tfrecords concurrently (relates to #79)

Big change

Previously, creating a dataset was a two-step process:
1. generate the full set of instances, with params for each, and save them to CSV file(s)
2. make the corresponding tfrecord(s) from the CSVs, using the params to create the model LC on which the mags feature is based; this is also where the train/val/test split happens

With this change, the CSVs and tfrecords are written concurrently. The benefit is that we now assemble all of the data for each instance (including the mags feature) before saving it to both CSV and tfrecord, so we can add logic which discards or modifies an instance based on the outcome of generating its mags feature. Previously, membership and labels were decided at the CSV stage and could not be updated when generating the mags data.
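For illustration only, a minimal sketch of the new single-pass flow; the helpers `make_mags_feature` and `to_tf_example` are hypothetical stand-ins, not the actual API (which lives in `datasets.make_dataset`):

```python
import csv
import tensorflow as tf

def write_csv_and_tfrecord(instances, csv_path, tfrecord_path):
    """Fully assemble each instance (including its mags feature) before
    persisting it, so unusable instances are dropped from *both* outputs."""
    with open(csv_path, "w", newline="", encoding="utf8") as csv_file, \
            tf.io.TFRecordWriter(str(tfrecord_path)) as tf_writer:
        csv_writer = None
        for params in instances:
            mags = make_mags_feature(params)   # hypothetical: generate the model LC / mags
            if mags is None:                   # e.g. LC generation failed or was rejected
                continue                       # discarded before either file sees it
            if csv_writer is None:             # lazily write the CSV header once
                csv_writer = csv.DictWriter(csv_file, fieldnames=sorted(params))
                csv_writer.writeheader()
            csv_writer.writerow(params)
            tf_writer.write(to_tf_example(params, mags).SerializeToString())
```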

Because of changes (fixes) to the random generators/seeds, datasets generated from here on will not have the same members as before. However, consistency is now improved going forward: each generator derives its seed deterministically from its label.
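A minimal sketch of that seeding approach, as used in the diff below (simplified, not the verbatim code):

```python
import hashlib
import numpy as np

def rng_for_label(label: str) -> np.random.Generator:
    # The built-in hash() is salted per process (PYTHONHASHSEED), so it would
    # give each worker a different seed on every run. A stable hash of the
    # label yields a reproducible, per-label random stream instead.
    seed = int.from_bytes(hashlib.shake_128(label.encode("utf8")).digest(8), "big")
    return np.random.default_rng(seed)
```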

Fringe benefit: the new code uses less RAM and is slightly faster than the previous approach.
SteveOv committed Aug 2, 2024
1 parent 3ec9c9b commit de37f60
Showing 3 changed files with 244 additions and 307 deletions.
86 changes: 42 additions & 44 deletions make_synthetic_test_dataset.py
@@ -2,6 +2,7 @@
import os
from pathlib import Path
from contextlib import redirect_stdout
+import hashlib

import numpy as np

@@ -62,7 +63,11 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
# pylint: disable=too-many-locals, too-many-statements, invalid-name
generated_counter = 0
usable_counter = 0
-    set_id = label.replace("trainset", "")
+    set_id = ''.join(filter(str.isdigit, label))
+
+    # Don't use the built-in hash() function; it's not consistent across processes!!!
+    seed = int.from_bytes(hashlib.shake_128(label.encode("utf8")).digest(8))
+    rng = np.random.default_rng(seed)

mission = Mission.get_instance("TESS")

@@ -74,10 +79,10 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
while usable_counter < instance_count:
while True: # imitate "loop and a half" / do ... until logic
# Get a list of initial masses at a random metallicity & age to choose our stars from
-            feh = np.random.choice(feh_values)
+            feh = rng.choice(feh_values)
ages = _mist_isochones.list_ages(feh, min_phase, max_phase)
while True:
-                age = np.random.choice(ages) * u.dex(u.yr)
+                age = rng.choice(ages) * u.dex(u.yr)
init_masses = _mist_isochones.list_initial_masses(feh, age.value,
min_phase, max_phase,
min_mass_value, max_mass_value)
@@ -89,11 +94,11 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
probs = np.power(init_masses, -2.35) # Salpeter IMF
probs *= np.tanh(0.31 * init_masses + .18) # Wells & Prsa Primary multiplicity frac
probs = np.divide(probs, np.sum(probs)) # Scaled to get a pmf() == 1
-            init_MA = np.random.choice(init_masses, p=probs) * u.solMass
+            init_MA = rng.choice(init_masses, p=probs) * u.solMass

init_MB_mask = (init_masses >= min_mass_value) & (init_masses < init_MA.value)
if any(init_MB_mask):
-                init_MB = np.random.choice(init_masses[init_MB_mask]) * u.solMass
+                init_MB = rng.choice(init_masses[init_MB_mask]) * u.solMass
else:
init_MB = init_MA

@@ -115,12 +120,12 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
# . 3(RA+RB) / 2(1-e) (Wells & Prsa) (assuming e==0 for now)
# . max(5*RA, 5*RB) (based on JKTEBOP recommendation for rA <= 0.2, rB <= 0.2)
a_min = max(3/2*(RA+RB), 5*RA, 5*RB)
-            per_min = orbital.orbital_period(MA, MB, a_min)
+            per_min = orbital.orbital_period(MA, MB, a_min).to(u.d).value

# We generate period, inc, and omega (argument of periastron) from uniform distributions
-            per = np.random.uniform(low=per_min.to(u.d).value, high=25) * u.d
-            inc = np.random.uniform(low=50., high=90.00001) * u.deg
-            omega = np.random.uniform(low=0., high=360.) * u.deg
+            per = rng.uniform(low=per_min, high=max(per_min, 25)) * u.d
+            inc = rng.uniform(low=50., high=90.00001) * u.deg
+            omega = rng.uniform(low=0., high=360.) * u.deg

# Eccentricity from uniform distribution, subject to a maximum value which depends on
            # orbital period/separation (again, based on Wells & Prsa; Moe & Di Stefano)
@@ -129,11 +134,11 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
ecc = 0
else:
e_max = max(min(1-(per.value/2)**(-2/3), 1-(1.5*(RA+RB)/a).value), 0)
-                ecc = np.random.uniform(low=0, high=e_max)
+                ecc = rng.uniform(low=0, high=e_max)

# We're once more predicting L3 as JKTEBOP is being updated to support
# negative L3 input values (so it's now fully trainable)
-            L3 = np.random.normal(0., 0.1)
+            L3 = rng.normal(0., 0.1)
L3 = 0 # continue to override this as L3 doesn't train well

# Now we can calculate other params which we need to decide whether to use this
@@ -165,7 +170,7 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
# Now we have to decide an appropriate Gaussian noise SNR to apply.
# Randomly choose an apparent mag in the TESS photometric range then derive
# the SNR (based on a linear regression fit of Álvarez et al. (2024) Table 2).
-            apparent_mag = np.random.uniform(6, 18)
+            apparent_mag = rng.uniform(6, 18)
snr = np.add(np.multiply(apparent_mag, -2.32), 59.4)

yield {
@@ -225,35 +230,28 @@ def generate_instances_from_mist_models(instance_count: int, label: str, verbose
# which generates random plausible dEB systems based on MIST stellar models.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
-    with redirect_stdout(Tee(open(dataset_dir / "trainset.log",
-                               "w",
-                               encoding="utf8"))):
-        datasets.generate_dataset_csvs(instance_count=DATASET_SIZE,
-                                       file_count=10,
-                                       output_dir=dataset_dir,
-                                       generator_func=generate_instances_from_mist_models,
-                                       file_pattern="trainset{0:03d}.csv",
-                                       verbose=True,
-                                       simulate=False)
-
-    plots.plot_trainset_histograms(dataset_dir, dataset_dir / "synth-histogram-full.png", cols=4)
-    plots.plot_trainset_histograms(dataset_dir, dataset_dir / "synth-histogram-main.eps", cols=2,
-                                   params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
-
-    with redirect_stdout(Tee(open(dataset_dir/"dataset.log",
-                               "a" if RESUME else "w",
-                               encoding="utf8"))):
-        datasets.make_dataset_files(trainset_files=sorted(dataset_dir.glob("trainset*.csv")),
-                                    output_dir=dataset_dir,
-                                    valid_ratio=0.,
-                                    test_ratio=1.,
-                                    resume=RESUME,
-                                    max_workers=5,
-                                    verbose=True,
-                                    simulate=False)
-
-    # Simple diagnostic plot of the mags feature of randomly sampled instances.
-    dataset_files = sorted(dataset_dir.glob("**/*.tfrecord"))
-    ids, _, _, _ = deb_example.read_dataset(dataset_files)
-    fig = plots.plot_dataset_instance_mags_features(dataset_files, np.random.choice(ids, 30))
-    fig.savefig(dataset_dir / "sample.pdf")

+    with redirect_stdout(Tee(open(dataset_dir/"dataset.log", "w", encoding="utf8"))):
+        datasets.make_dataset(instance_count=DATASET_SIZE,
+                              file_count=10,
+                              output_dir=dataset_dir,
+                              generator_func=generate_instances_from_mist_models,
+                              file_prefix="trainset",
+                              valid_ratio=0.,
+                              test_ratio=1.,
+                              max_workers=5,
+                              save_param_csvs=True,
+                              verbose=True,
+                              simulate=False)
+
+    # TODO: Update plot_trainset_histograms so that we can change name of the csv/dataset files
+    # Histograms are generated from the CSV files (as they cover params not in the dataset)
+    plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-full.png", cols=4)
+    plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-main.eps", cols=2,
+                                   params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
+
+    # Simple diagnostic plot of the mags feature of a small sample of the instances.
+    dataset_files = sorted(dataset_dir.glob("**/*.tfrecord"))
+    ids, _, _, _ = deb_example.read_dataset(dataset_files)
+    fig = plots.plot_dataset_instance_mags_features(dataset_files, ids[:30])
+    fig.savefig(dataset_dir / "sample.pdf")
67 changes: 31 additions & 36 deletions make_training_dataset.py
@@ -2,6 +2,7 @@
import os
from pathlib import Path
from contextlib import redirect_stdout
+import hashlib

import numpy as np

@@ -34,7 +35,6 @@
# - it's a convenient break in the process

DATASET_SIZE = 250000
-RESUME = False
dataset_dir = Path(f"./datasets/formal-training-dataset-{DATASET_SIZE // 1000}k/")
dataset_dir.mkdir(parents=True, exist_ok=True)

@@ -51,32 +51,36 @@ def generate_instances_from_distributions(instance_count: int, label: str, verbose
# pylint: disable=too-many-locals, invalid-name
generated_counter = 0
usable_counter = 0
-    set_id = label.replace("trainset", "")
+    set_id = ''.join(filter(str.isdigit, label))
+
+    # Don't use the built-in hash() function; it's not consistent across processes!!!
+    seed = int.from_bytes(hashlib.shake_128(label.encode("utf8")).digest(8))
+    rng = np.random.default_rng(seed)

while usable_counter < instance_count:
while True: # imitate "loop and a half" / "repeat ... until" logic
# These are the "label" params for which we have defined distributions
-            rA_plus_rB = np.random.uniform(low=0.001, high=0.45001)
-            k = np.random.normal(loc=0.8, scale=0.4)
-            inc = np.random.uniform(low=50., high=90.00001) * u.deg
-            J = np.random.normal(loc=0.8, scale=0.4)
+            rA_plus_rB = rng.uniform(low=0.001, high=0.45001)
+            k = rng.normal(loc=0.8, scale=0.4)
+            inc = rng.uniform(low=50., high=90.00001) * u.deg
+            J = rng.normal(loc=0.8, scale=0.4)

# We need a version of JKTEBOP which supports negative L3 input values
# (not so for version 43) in order to train a model to predict L3.
-            L3 = np.random.normal(0., 0.1)
+            L3 = rng.normal(0., 0.1)
L3 = 0 # continue to override until revised JKTEBOP released

# The qphot mass ratio value (MB/MA) affects the lightcurve via the ellipsoidal effect
# due to distortion of the stars' shape. Set to -100 to force spherical stars or derive
# a value from other params. We're using the k-q relations of Demircan & Kahraman (1991)
# Both <1.66 M_sun (k=q^0.935), both >1.66 M_sun (k=q^0.542), MB-low/MA-high (k=q^0.724)
# and approx' single rule is k = q^0.715 which we use here (tests find this works best).
-            qphot = np.random.normal(loc=k**1.4, scale=0.3) if k > 0 else 0
+            qphot = rng.normal(loc=k**1.4, scale=0.3) if k > 0 else 0

# We generate ecc and omega (argument of periastron) from appropriate distributions.
# They're not used directly as labels, but they make up ecosw and esinw which are.
-            ecc = np.abs(np.random.normal(loc=0.0, scale=0.2))
-            omega = np.random.uniform(low=0., high=360.) * u.deg
+            ecc = np.abs(rng.normal(loc=0.0, scale=0.2))
+            omega = rng.uniform(low=0., high=360.) * u.deg

# Now we can calculate the derived values, sufficient to check we've a usable system
inc_rad = inc.to(u.rad).value
Expand Down Expand Up @@ -132,29 +136,20 @@ def generate_instances_from_distributions(instance_count: int, label: str, verbo
# samples parameter distributions over JKTEBOP's usable range.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
-    with redirect_stdout(Tee(open(dataset_dir / "trainset.log",
-                               "w",
-                               encoding="utf8"))):
-        datasets.generate_dataset_csvs(instance_count=DATASET_SIZE,
-                                       file_count=DATASET_SIZE // 10000,
-                                       output_dir=dataset_dir,
-                                       generator_func=generate_instances_from_distributions,
-                                       file_pattern="trainset{0:03d}.csv",
-                                       verbose=True,
-                                       simulate=False)
-
-    plots.plot_trainset_histograms(dataset_dir, dataset_dir / "train-histogram-full.png", cols=3)
-    plots.plot_trainset_histograms(dataset_dir, dataset_dir / "train-histogram-main.eps", cols=2,
-                                   params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
-
-    with redirect_stdout(Tee(open(dataset_dir / "dataset.log",
-                               "a" if RESUME else "w",
-                               encoding="utf8"))):
-        datasets.make_dataset_files(trainset_files=sorted(dataset_dir.glob("trainset*.csv")),
-                                    output_dir=dataset_dir,
-                                    valid_ratio=0.2,
-                                    test_ratio=0,
-                                    resume=RESUME,
-                                    max_workers=5,
-                                    verbose=True,
-                                    simulate=False)

+    with redirect_stdout(Tee(open(dataset_dir/"dataset.log", "w", encoding="utf8"))):
+        datasets.make_dataset(instance_count=DATASET_SIZE,
+                              file_count=DATASET_SIZE // 10000,
+                              output_dir=dataset_dir,
+                              generator_func=generate_instances_from_distributions,
+                              file_prefix="trainset",
+                              valid_ratio=0.2,
+                              test_ratio=0,
+                              max_workers=5,
+                              save_param_csvs=True,
+                              verbose=True,
+                              simulate=False)
+
+    plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-full.png", cols=3)
+    plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-main.eps", cols=2,
+                                   params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
