Skip to content

Commit

Permalink
eval fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Oct 17, 2024
1 parent bbb2226 commit 888d05b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
4 changes: 2 additions & 2 deletions config/model_config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
device: 0 # 0, cpu
device: 1 # 0, cpu
seq_len: 96 # should not be changed for the current datasets
input_dim: 1 # or 1 depending on user, but is dynamically set
noise_dim: 256
cond_emb_dim: 64
shuffle: True
sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1]
sparse_conditioning_loss_weight: 0.5 # sparse conditioning training sample weight for loss computation [0, 1]
freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs
save_cycle: 200 # specify number of epochs to save model after

Expand Down
9 changes: 6 additions & 3 deletions eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,19 +323,22 @@ def create_visualizations(
range_plot = plot_range_with_syn_values(
real_data_df, generated_samples_df, month, weekday
)
writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot)
if range_plot is not None:
writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot)

# Visualization 2: Plot closest real signals with synthetic values
closest_plot = plot_syn_with_closest_real_ts(
real_data_df, generated_samples_df, month, weekday
)
writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot)
if closest_plot is not None:
writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot)

# Visualization 3: KDE plots for real and synthetic data
real_data_array = np.stack(real_data_df["timeseries"])
syn_data_array = np.stack(syn_data_df["timeseries"])
kde_plot = visualization(real_data_array, syn_data_array, "kernel")
writer.add_figure(f"Visualizations/KDE", kde_plot)
if kde_plot is None:
writer.add_figure(f"Visualizations/KDE", kde_plot)

def get_trained_model(self, dataset: Any) -> Any:
"""
Expand Down
12 changes: 5 additions & 7 deletions eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def plot_range_with_syn_values(
alpha=0.6,
)

plt.title(f"Range of Values and Synthetic Data for {weekday_name} in {month_name}")
plt.title(f"Range of Values and Synthetic Data for {weekday_name}s in {month_name}")
plt.xlabel("Time of day")
plt.ylabel("Electric load in kWh")

Expand All @@ -394,8 +394,7 @@ def plot_range_with_syn_values(

plt.legend()
plt.tight_layout()
plt.show()
return f
return plt.gcf()


def plot_syn_with_closest_real_ts(
Expand Down Expand Up @@ -432,7 +431,7 @@ def plot_syn_with_closest_real_ts(
syn_values = np.array([ts[:, dimension] for ts in syn_filtered_df["timeseries"]])

# Generate timestamps at 15-minute intervals
timestamps = pd.date_range(start="00:00", end="23:45", freq="15T")
timestamps = pd.date_range(start="00:00", end="23:45", freq="15min")

hourly_positions, hourly_labels = get_hourly_ticks(timestamps)
month_name, weekday_name = get_month_weekday_names(month, weekday)
Expand Down Expand Up @@ -503,7 +502,7 @@ def plot_syn_with_closest_real_ts(
)

plt.title(
f"Synthetic vs Closest Real Time Series for {weekday_name} in {month_name}"
f"Synthetic vs Closest Real Time Series for {weekday_name}s in {month_name}"
)
plt.xlabel("Time of day")
plt.ylabel("Electric load in kWh")
Expand All @@ -513,5 +512,4 @@ def plot_syn_with_closest_real_ts(

plt.legend()
plt.tight_layout()
plt.show()
return f
return plt.gcf()
28 changes: 14 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,24 @@ def main():
# evaluate_individual_user_models("gpt", include_generation=False)
# evaluate_individual_user_models("acgan", include_generation=True)
# evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date")
# evaluate_single_dataset_model(
# "diffusion_ts",
# # geography="california",
# include_generation=True,
# normalization_method="group",
# )
dataset_manager = PecanStreetDataManager(
evaluate_single_dataset_model(
"diffusion_ts",
geography="california",
normalize=True,
include_generation=True,
normalization_method="group",
threshold=(-5, 5),
)
non_pv_user_dataset = dataset_manager.create_non_pv_user_dataset()
generator = DataGenerator("diffusion_ts")
generator.load("checkpoints/2024-10-15_05-09-19/diffusion_ts_checkpoint_1000.pt")
evaluator = Evaluator(non_pv_user_dataset, "diffusion_ts")
evaluator.evaluate_model(model=generator.model, data_label="pre-loaded")
# dataset_manager = PecanStreetDataManager(
# geography="california",
# normalize=True,
# include_generation=True,
# normalization_method="group",
# threshold=(-5, 5),
# )
# non_pv_user_dataset = dataset_manager.create_non_pv_user_dataset()
# generator = DataGenerator("diffusion_ts")
# generator.load("checkpoints/2024-10-15_05-09-44/diffusion_ts_checkpoint_1000.pt")
# evaluator = Evaluator(non_pv_user_dataset, "diffusion_ts")
# evaluator.evaluate_model(model=generator.model, data_label="pre-loaded")


if __name__ == "__main__":
Expand Down

0 comments on commit 888d05b

Please sign in to comment.