Skip to content

Commit

Permalink
dataset histograms now support any csv file names (relates to #79)
Browse files Browse the repository at this point in the history
Previously the plot_trainset_histogram() function was hard coded to look for trainset*.csv files. Its replacement now takes an iterable of files, so it's up to the caller to glob the appropriate files (which can now be called whatever is wanted).

Took the opportunity to further revise this to return a fig as the other plot functions in this module do.
  • Loading branch information
SteveOv committed Aug 2, 2024
1 parent 628a40d commit db08e0c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 37 deletions.
17 changes: 8 additions & 9 deletions make_synthetic_test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ebop_maven.libs.tee import Tee

DATASET_SIZE = 20000
RESUME = False
FILE_PREFIX = "trainset"
dataset_dir = Path("./datasets/synthetic-mist-tess-dataset/")
dataset_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -216,28 +216,27 @@ def generate_instances_from_mist_models(label: str, verbose: bool=False):
# which generates random plausible dEB systems based on MIST stellar models.
# ------------------------------------------------------------------------------
if __name__ == "__main__":

with redirect_stdout(Tee(open(dataset_dir/"dataset.log", "w", encoding="utf8"))):
datasets.make_dataset(instance_count=DATASET_SIZE,
file_count=10,
output_dir=dataset_dir,
generator_func=generate_instances_from_mist_models,
file_prefix="trainset",
file_prefix=FILE_PREFIX,
valid_ratio=0.,
test_ratio=1.,
max_workers=5,
save_param_csvs=True,
verbose=True,
simulate=False)

# TODO: Update plot_trainset_histograms so that we can change name of the csv/dataset files
# Histograms are generated from the CSV files (as they cover params not in the dataset)
plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-full.png", cols=4)
plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-main.eps", cols=2,
params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
# Histograms are generated from the CSV files as they cover params not saved to tfrecord
csvs = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.csv"))
plots.plot_dataset_histograms(csvs, cols=4).savefig(dataset_dir/"synth-histogram-full.png")
plots.plot_dataset_histograms(csvs, ["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"],
cols=2).savefig(dataset_dir/"synth-histogram-main.eps")

# Simple diagnostic plot of the mags feature of a small sample of the instances.
dataset_files = sorted(dataset_dir.glob("**/*.tfrecord"))
dataset_files = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.tfrecord"))
ids, _, _, _ = deb_example.read_dataset(dataset_files)
fig = plots.plot_dataset_instance_mags_features(dataset_files, ids[:30])
fig.savefig(dataset_dir / "sample.pdf")
11 changes: 7 additions & 4 deletions make_training_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ebop_maven.libs.tee import Tee

DATASET_SIZE = 250000
FILE_PREFIX = "trainset"
dataset_dir = Path(f"./datasets/formal-training-dataset-{DATASET_SIZE // 1000}k/")
dataset_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -129,14 +130,16 @@ def generate_instances_from_distributions(label: str, verbose: bool=False):
file_count=DATASET_SIZE // 10000,
output_dir=dataset_dir,
generator_func=generate_instances_from_distributions,
file_prefix="trainset",
file_prefix=FILE_PREFIX,
valid_ratio=0.2,
test_ratio=0,
max_workers=5,
save_param_csvs=True,
verbose=True,
simulate=False)

plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-full.png", cols=3)
plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-main.eps", cols=2,
params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"])
# Histograms are generated from the CSV files as they cover params not saved to tfrecord
csvs = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.csv"))
plots.plot_dataset_histograms(csvs, cols=3).savefig(dataset_dir/"train-histogram-full.png")
plots.plot_dataset_histograms(csvs, ["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"],
cols=2).savefig(dataset_dir/"train-histogram-main.eps")
42 changes: 18 additions & 24 deletions traininglib/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,42 +54,41 @@
}


def plot_trainset_histograms(trainset_dir: Path,
plot_file: Path=None,
params: List[str]=None,
cols: int=3,
yscale: str="log",
verbose: bool=True):
def plot_dataset_histograms(csv_files: Iterable[Path],
params: List[str]=None,
cols: int=3,
yscale: str="log",
verbose: bool=True):
"""
Saves histogram plots to a single figure on a grid of axes. The params will
be plotted in the order they are listed, scanning from left to right and down.
Saves histogram plots to a single figure on a grid of axes. The params will be plotted
in the order they are listed, scanning from left to right and down. These are generated
from the dataset's CSV files as they may plot params not written to the dataset tfrecords.
:trainset_dir: the directory containing the trainset csv files
:plot_file: the directory to save the plots. If none, they're saved with the trainset
:parameters: the list of parameters to plot, or the full list if None.
:csv_files: a list of the dataset's csv files
:parames: the list of parameters to plot, or the full list if None.
See the histogram_parameters attribute for the full list
:cols: the width of the axes grid (the rows automatically adjust)
:yscale: set to "linear" or "log" to control the y-axis scale
:verbose: whether to print verbose progress/diagnostic messages
"""
# pylint: disable=too-many-arguments
csvs = sorted(trainset_dir.glob("trainset*.csv"))

csv_files = sorted(csv_files) # Happy for this to error if there's a problem
if not params:
params = get_field_names_from_csvs(csvs)
params = get_field_names_from_csvs(csv_files)
param_specs = { p: all_histogram_params[p] for p in params if p in all_histogram_params }

if param_specs and csvs:
fig = None
if param_specs and csv_files:
rows = math.ceil(len(param_specs) / cols)
_, axes = plt.subplots(rows, cols, sharey="all",
figsize=(cols*3, rows*2.5), constrained_layout=True)
fig, axes = plt.subplots(rows, cols, sharey="all",
figsize=(cols*3, rows*2.5), constrained_layout=True)
if verbose:
print(f"Plotting histograms in a {cols}x{rows} grid for:", ", ".join(param_specs))

for (ax, field) in zip_longest(axes.flatten(), param_specs):
if field:
bins, label = param_specs[field]
data = [row.get(field, None) for row in read_param_sets_from_csvs(csvs)]
data = [row.get(field, None) for row in read_param_sets_from_csvs(csv_files)]
if verbose:
print(f"Plotting histogram for {len(data):,} {field} values.")
ax.hist(data, bins=bins)
Expand All @@ -99,12 +98,7 @@ def plot_trainset_histograms(trainset_dir: Path,
ax.set_yscale(yscale)
else:
ax.axis("off") # remove the unused ax

if verbose:
print("Saving histogram plot to", plot_file)
plot_file.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(plot_file, dpi=100) # dpi is ignored for vector formats
plt.close()
return fig


def plot_formal_test_dataset_hr_diagram(targets_cfg: Dict[str, any],
Expand Down

0 comments on commit db08e0c

Please sign in to comment.