From db08e0c8fd76e2c572922f4bc839dc5fb9bd5699 Mon Sep 17 00:00:00 2001 From: SteveOv Date: Fri, 2 Aug 2024 17:50:51 +0100 Subject: [PATCH] dataset histograms now support any csv file names (relates to #79) Previously the plot_trainset_histogram() function was hard coded to look for trainset*.csv files. Its replacement now takes an iterable of files, so it's up to the caller to glob the appropriate files (which can now be called whatever is wanted). Took the opportunity to further revise this to return a fig as the other plot functions in this module do. --- make_synthetic_test_dataset.py | 17 +++++++------- make_training_dataset.py | 11 +++++---- traininglib/plots.py | 42 +++++++++++++++------------------- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/make_synthetic_test_dataset.py b/make_synthetic_test_dataset.py index 949d080..453c7f6 100644 --- a/make_synthetic_test_dataset.py +++ b/make_synthetic_test_dataset.py @@ -25,7 +25,7 @@ from ebop_maven.libs.tee import Tee DATASET_SIZE = 20000 -RESUME = False +FILE_PREFIX = "trainset" dataset_dir = Path("./datasets/synthetic-mist-tess-dataset/") dataset_dir.mkdir(parents=True, exist_ok=True) @@ -216,13 +216,12 @@ def generate_instances_from_mist_models(label: str, verbose: bool=False): # which generates random plausible dEB systems based on MIST stellar models. # ------------------------------------------------------------------------------ if __name__ == "__main__": - with redirect_stdout(Tee(open(dataset_dir/"dataset.log", "w", encoding="utf8"))): datasets.make_dataset(instance_count=DATASET_SIZE, file_count=10, output_dir=dataset_dir, generator_func=generate_instances_from_mist_models, - file_prefix="trainset", + file_prefix=FILE_PREFIX, valid_ratio=0., test_ratio=1., max_workers=5, @@ -230,14 +229,14 @@ def generate_instances_from_mist_models(label: str, verbose: bool=False): verbose=True, simulate=False) - # TODO: Update plot_trainset_histograms so that we can change name of the csv/dataset files - # Histograms are generated from the CSV files (as they cover params not in the dataset) - plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-full.png", cols=4) - plots.plot_trainset_histograms(dataset_dir, dataset_dir/"synth-histogram-main.eps", cols=2, - params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"]) + # Histograms are generated from the CSV files as they cover params not saved to tfrecord + csvs = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.csv")) + plots.plot_dataset_histograms(csvs, cols=4).savefig(dataset_dir/"synth-histogram-full.png") + plots.plot_dataset_histograms(csvs, ["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"], + cols=2).savefig(dataset_dir/"synth-histogram-main.eps") # Simple diagnostic plot of the mags feature of a small sample of the instances. - dataset_files = sorted(dataset_dir.glob("**/*.tfrecord")) + dataset_files = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.tfrecord")) ids, _, _, _ = deb_example.read_dataset(dataset_files) fig = plots.plot_dataset_instance_mags_features(dataset_files, ids[:30]) fig.savefig(dataset_dir / "sample.pdf") diff --git a/make_training_dataset.py b/make_training_dataset.py index 95448bc..2118a5d 100644 --- a/make_training_dataset.py +++ b/make_training_dataset.py @@ -22,6 +22,7 @@ from ebop_maven.libs.tee import Tee DATASET_SIZE = 250000 +FILE_PREFIX = "trainset" dataset_dir = Path(f"./datasets/formal-training-dataset-{DATASET_SIZE // 1000}k/") dataset_dir.mkdir(parents=True, exist_ok=True) @@ -129,7 +130,7 @@ def generate_instances_from_distributions(label: str, verbose: bool=False): file_count=DATASET_SIZE // 10000, output_dir=dataset_dir, generator_func=generate_instances_from_distributions, - file_prefix="trainset", + file_prefix=FILE_PREFIX, valid_ratio=0.2, test_ratio=0, max_workers=5, @@ -137,6 +138,8 @@ def generate_instances_from_distributions(label: str, verbose: bool=False): verbose=True, simulate=False) - plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-full.png", cols=3) - plots.plot_trainset_histograms(dataset_dir, dataset_dir/"train-histogram-main.eps", cols=2, - params=["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"]) + # Histograms are generated from the CSV files as they cover params not saved to tfrecord + csvs = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.csv")) + plots.plot_dataset_histograms(csvs, cols=3).savefig(dataset_dir/"train-histogram-full.png") + plots.plot_dataset_histograms(csvs, ["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"], + cols=2).savefig(dataset_dir/"train-histogram-main.eps") diff --git a/traininglib/plots.py b/traininglib/plots.py index 6f7eb5b..660f9f2 100644 --- a/traininglib/plots.py +++ b/traininglib/plots.py @@ -54,42 +54,41 @@ } -def plot_trainset_histograms(trainset_dir: Path, - plot_file: Path=None, - params: List[str]=None, - cols: int=3, - yscale: str="log", - verbose: bool=True): +def plot_dataset_histograms(csv_files: Iterable[Path], + params: List[str]=None, + cols: int=3, + yscale: str="log", + verbose: bool=True): """ - Saves histogram plots to a single figure on a grid of axes. The params will - be plotted in the order they are listed, scanning from left to right and down. + Saves histogram plots to a single figure on a grid of axes. The params will be plotted + in the order they are listed, scanning from left to right and down. These are generated + from the dataset's CSV files as they may plot params not written to the dataset tfrecords. - :trainset_dir: the directory containing the trainset csv files - :plot_file: the directory to save the plots. If none, they're saved with the trainset - :parameters: the list of parameters to plot, or the full list if None. + :csv_files: a list of the dataset's csv files + :parames: the list of parameters to plot, or the full list if None. See the histogram_parameters attribute for the full list :cols: the width of the axes grid (the rows automatically adjust) :yscale: set to "linear" or "log" to control the y-axis scale :verbose: whether to print verbose progress/diagnostic messages """ # pylint: disable=too-many-arguments - csvs = sorted(trainset_dir.glob("trainset*.csv")) - + csv_files = sorted(csv_files) # Happy for this to error if there's a problem if not params: - params = get_field_names_from_csvs(csvs) + params = get_field_names_from_csvs(csv_files) param_specs = { p: all_histogram_params[p] for p in params if p in all_histogram_params } - if param_specs and csvs: + fig = None + if param_specs and csv_files: rows = math.ceil(len(param_specs) / cols) - _, axes = plt.subplots(rows, cols, sharey="all", - figsize=(cols*3, rows*2.5), constrained_layout=True) + fig, axes = plt.subplots(rows, cols, sharey="all", + figsize=(cols*3, rows*2.5), constrained_layout=True) if verbose: print(f"Plotting histograms in a {cols}x{rows} grid for:", ", ".join(param_specs)) for (ax, field) in zip_longest(axes.flatten(), param_specs): if field: bins, label = param_specs[field] - data = [row.get(field, None) for row in read_param_sets_from_csvs(csvs)] + data = [row.get(field, None) for row in read_param_sets_from_csvs(csv_files)] if verbose: print(f"Plotting histogram for {len(data):,} {field} values.") ax.hist(data, bins=bins) @@ -99,12 +98,7 @@ def plot_trainset_histograms(trainset_dir: Path, ax.set_yscale(yscale) else: ax.axis("off") # remove the unused ax - - if verbose: - print("Saving histogram plot to", plot_file) - plot_file.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(plot_file, dpi=100) # dpi is ignored for vector formats - plt.close() + return fig def plot_formal_test_dataset_hr_diagram(targets_cfg: Dict[str, any],