diff --git a/config/model_config.yaml b/config/model_config.yaml index 7a4e605..ba8f983 100644 --- a/config/model_config.yaml +++ b/config/model_config.yaml @@ -1,10 +1,10 @@ -device: 0 # 0, cpu +device: 1 # 0, cpu seq_len: 96 # should not be changed for the current datasets input_dim: 1 # or 1 depending on user, but is dynamically set noise_dim: 256 cond_emb_dim: 64 shuffle: True -sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1] +sparse_conditioning_loss_weight: 0.5 # sparse conditioning training sample weight for loss computation [0, 1] freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs save_cycle: 200 # specify number of epochs to save model after diff --git a/eval/evaluator.py b/eval/evaluator.py index 6475cda..235e888 100644 --- a/eval/evaluator.py +++ b/eval/evaluator.py @@ -323,19 +323,22 @@ def create_visualizations( range_plot = plot_range_with_syn_values( real_data_df, generated_samples_df, month, weekday ) - writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot) + if range_plot is not None: + writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot) # Visualization 2: Plot closest real signals with synthetic values closest_plot = plot_syn_with_closest_real_ts( real_data_df, generated_samples_df, month, weekday ) - writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot) + if closest_plot is not None: + writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot) # Visualization 3: KDE plots for real and synthetic data real_data_array = np.stack(real_data_df["timeseries"]) syn_data_array = np.stack(syn_data_df["timeseries"]) kde_plot = visualization(real_data_array, syn_data_array, "kernel") - writer.add_figure(f"Visualizations/KDE", kde_plot) + if kde_plot is None: + writer.add_figure(f"Visualizations/KDE", kde_plot) def get_trained_model(self, dataset: Any) -> Any: """ diff --git a/eval/metrics.py b/eval/metrics.py index 39d3fc7..b7f6917 100644 --- a/eval/metrics.py +++ b/eval/metrics.py @@ -384,7 +384,7 @@ def plot_range_with_syn_values( alpha=0.6, ) - plt.title(f"Range of Values and Synthetic Data for {weekday_name} in {month_name}") + plt.title(f"Range of Values and Synthetic Data for {weekday_name}s in {month_name}") plt.xlabel("Time of day") plt.ylabel("Electric load in kWh") @@ -394,8 +394,7 @@ def plot_range_with_syn_values( plt.legend() plt.tight_layout() - plt.show() - return f + return plt.gcf() def plot_syn_with_closest_real_ts( @@ -432,7 +431,7 @@ def plot_syn_with_closest_real_ts( syn_values = np.array([ts[:, dimension] for ts in syn_filtered_df["timeseries"]]) # Generate timestamps at 15-minute intervals - timestamps = pd.date_range(start="00:00", end="23:45", freq="15T") + timestamps = pd.date_range(start="00:00", end="23:45", freq="15min") hourly_positions, hourly_labels = get_hourly_ticks(timestamps) month_name, weekday_name = get_month_weekday_names(month, weekday) @@ -503,7 +502,7 @@ def plot_syn_with_closest_real_ts( ) plt.title( - f"Synthetic vs Closest Real Time Series for {weekday_name} in {month_name}" + f"Synthetic vs Closest Real Time Series for {weekday_name}s in {month_name}" ) plt.xlabel("Time of day") plt.ylabel("Electric load in kWh") @@ -513,5 +512,4 @@ def plot_syn_with_closest_real_ts( plt.legend() plt.tight_layout() - plt.show() - return f + return plt.gcf() diff --git a/main.py b/main.py index d26e4e0..fbda225 100644 --- a/main.py +++ b/main.py @@ -46,24 +46,24 @@ def main(): # evaluate_individual_user_models("gpt", include_generation=False) # evaluate_individual_user_models("acgan", include_generation=True) # evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date") - # evaluate_single_dataset_model( - # "diffusion_ts", - # # geography="california", - # include_generation=True, - # normalization_method="group", - # ) - dataset_manager = PecanStreetDataManager( + evaluate_single_dataset_model( + "diffusion_ts", geography="california", - normalize=True, include_generation=True, normalization_method="group", - threshold=(-5, 5), ) - non_pv_user_dataset = dataset_manager.create_non_pv_user_dataset() - generator = DataGenerator("diffusion_ts") - generator.load("checkpoints/2024-10-15_05-09-19/diffusion_ts_checkpoint_1000.pt") - evaluator = Evaluator(non_pv_user_dataset, "diffusion_ts") - evaluator.evaluate_model(model=generator.model, data_label="pre-loaded") + # dataset_manager = PecanStreetDataManager( + # geography="california", + # normalize=True, + # include_generation=True, + # normalization_method="group", + # threshold=(-5, 5), + # ) + # non_pv_user_dataset = dataset_manager.create_non_pv_user_dataset() + # generator = DataGenerator("diffusion_ts") + # generator.load("checkpoints/2024-10-15_05-09-44/diffusion_ts_checkpoint_1000.pt") + # evaluator = Evaluator(non_pv_user_dataset, "diffusion_ts") + # evaluator.evaluate_model(model=generator.model, data_label="pre-loaded") if __name__ == "__main__":