diff --git a/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py b/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py index cf864e0b..15492edd 100644 --- a/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py +++ b/src/modules/applications/QML/generative_modeling/data/data_handler/DataHandler.py @@ -125,10 +125,13 @@ def postprocess(self, input_data: dict, config: dict, **kwargs): if self.generalization_mark is not None: np.save(f"{store_dir_iter}/histogram_generated.npy", evaluation["histogram_generated"]) else: - samples = input_data["best_sample"] - n_shots = np.sum(samples) - histogram_generated = np.asarray(samples) / n_shots - histogram_generated[histogram_generated == 0] = 1e-8 + if "best_sample" in list(input_data.keys()): + samples = input_data["best_sample"] + n_shots = np.sum(samples) + histogram_generated = np.asarray(samples) / n_shots + histogram_generated[histogram_generated == 0] = 1e-8 + else: + histogram_generated = input_data["histogram_generated"] np.save(f"{store_dir_iter}/histogram_generated.npy", histogram_generated) self.metrics.add_metric_batch({"histogram_generated": os.path.relpath( f"{store_dir_iter}/histogram_generated.npy_{kwargs['rep_count']}.npy", current_directory)})