From 0d0d9cada4006352ed711b18046dc837867afef0 Mon Sep 17 00:00:00 2001 From: Florian Kiwit Date: Wed, 22 Nov 2023 18:36:42 +0100 Subject: [PATCH] Fix problem with continuous dataset --- .../data/data_handler/DataHandler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py b/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py index cf864e0b..15492edd 100644 --- a/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py +++ b/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py @@ -125,10 +125,13 @@ def postprocess(self, input_data: dict, config: dict, **kwargs): if self.generalization_mark is not None: np.save(f"{store_dir_iter}/histogram_generated.npy", evaluation["histogram_generated"]) else: - samples = input_data["best_sample"] - n_shots = np.sum(samples) - histogram_generated = np.asarray(samples) / n_shots - histogram_generated[histogram_generated == 0] = 1e-8 + if "best_sample" in list(input_data.keys()): + samples = input_data["best_sample"] + n_shots = np.sum(samples) + histogram_generated = np.asarray(samples) / n_shots + histogram_generated[histogram_generated == 0] = 1e-8 + else: + histogram_generated = input_data["histogram_generated"] np.save(f"{store_dir_iter}/histogram_generated.npy", histogram_generated) self.metrics.add_metric_batch({"histogram_generated": os.path.relpath( f"{store_dir_iter}/histogram_generated.npy_{kwargs['rep_count']}.npy", current_directory)})