Skip to content

Commit

Permalink
Merge pull request #36 from NWC-CUAHSI-Summer-Institute/debug_ML
Browse files Browse the repository at this point in the history
Debug ml
  • Loading branch information
RY4GIT authored Nov 26, 2023
2 parents c41a8f1 + 6c1e034 commit e941e2f
Show file tree
Hide file tree
Showing 12 changed files with 70,310 additions and 23 deletions.
8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/DEEP_GW_TO_CHANNEL_FLUX.csv

Large diffs are not rendered by default.

8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/DIRECT_RUNOFF.csv

Large diffs are not rendered by default.

8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/GIUH_RUNOFF.csv

Large diffs are not rendered by default.

8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/NASH_LATERAL_RUNOFF.csv

Large diffs are not rendered by default.

8,786 changes: 8,785 additions & 1 deletion data/synthetic_case/SOIL_CONCEPTUAL_STORAGE.csv

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions data/synthetic_case/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,20 @@ data:
json_params_dir: ${data.data_dir}\cat_{}_bmi_config_cfe.json
partition_scheme: Schaake
start_time: '2010-10-01 00:00:00'
end_time: '2011-09-30 23:00:00'
end_time: '2012-09-30 23:00:00'
models:
hyperparameters:
epochs: 10
learning_rate: 1
epochs: 30
learning_rate: 0.01
warmup: 2000
step_size: 1
gamma: 0.7
gamma: 0.92
mlp:
hidden_size: 6
hidden_size: 256
num_attrs: 3
num_params: 2
num_states: 2
lag_hrs: 24
transformation:
Cgw:
- 1.8e-06
Expand All @@ -56,4 +57,4 @@ models:
- 0.000726
initial_params:
Cgw: 0.000255
satdk: 1.82e-08
satdk: 0.0182
8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/land_surface_water__runoff_depth.csv

Large diffs are not rendered by default.

8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/land_surface_water__runoff_volume_flux.csv

Large diffs are not rendered by default.

8,784 changes: 8,784 additions & 0 deletions data/synthetic_case/synthetic_classic.csv

Large diffs are not rendered by default.

38 changes: 25 additions & 13 deletions src/agents/DifferentiableCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from data.metrics import calculate_nse
from models.dCFE import dCFE
from utils.ddp_setup import find_free_port, cleanup

import shutil

log = logging.getLogger("agents.DifferentiableCFE")

Expand Down Expand Up @@ -77,6 +77,9 @@ def __init__(self, cfg: DictConfig) -> None:
]
)

self.Cgw_record = np.empty([self.data.num_basins, self.data.n_timesteps])
self.satdk_record = np.empty([self.data.num_basins, self.data.n_timesteps])

# # Prepare for the DDP
# free_port = find_free_port()
# os.environ["MASTER_ADDR"] = "localhost"
Expand All @@ -86,8 +89,9 @@ def create_output_dir(self):
current_date = datetime.now().strftime("%Y-%m-%d")
dir_name = f"{current_date}_output"
output_dir = os.path.join(self.cfg.output_dir, dir_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
return output_dir

def run(self):
Expand Down Expand Up @@ -126,6 +130,7 @@ def train(self) -> None:
self.run_model(run_mlp=False)

for epoch in range(1, self.cfg.models.hyperparameters.epochs + 1):
# TODO: Loop through basins
log.info(f"Epoch #: {epoch}/{self.cfg.models.hyperparameters.epochs}")
self.loss_record[epoch - 1] = self.train_one_epoch()
self.current_epoch += 1
Expand Down Expand Up @@ -167,6 +172,8 @@ def run_model(self, run_mlp=False):
for t, (x, y_t) in enumerate(tqdm(self.data_loader, desc="Processing data")):
if run_mlp:
self.model.mlp_forward(t) # Instead
self.Cgw_record[:, t] = self.model.Cgw.detach().numpy()
self.satdk_record[:, t] = self.model.satdk.detach().numpy()
runoff = self.model(x, t)
y_hat[:, t] = runoff

Expand Down Expand Up @@ -275,12 +282,17 @@ def save_loss(self):
df.to_csv(file_path)

fig, axes = plt.subplots()
axes.plot(self.loss_record, "-")
# Create the x-axis values
epoch_list = list(range(1, len(self.loss_record) + 1))

# Plotting
fig, axes = plt.subplots()
axes.plot(epoch_list, self.loss_record, "-")
axes.set_title(
f"Initial learning rate: {self.cfg.models.hyperparameters.learning_rate}"
)
axes.set_ylabel("loss")
axes.set_xlabel("epoch-1")
axes.set_xlabel("epoch")
fig.tight_layout()
plt.savefig(os.path.join(self.output_dir, f"final_result_loss.png"))
plt.close()
Expand All @@ -305,19 +317,15 @@ def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
def save_result(self, y_hat, y_t, out_filename, plot_figure=False):
# Save all basin runs

Cgw = self.model.Cgw.detach().numpy()
satdk_ = self.model.satdk.detach().numpy()

warmup = self.cfg.models.hyperparameters.warmup

for i, basin_id in enumerate(self.data.basin_ids):
# Save the timeseries of runoff and the best dynamic parametersers
# TODO: Save Cgw and satdk later
# "Cgw": Cgw[i, warmup:],
# "satdk": satdk_[i, warmup:],
data = {
"y_hat": y_hat[i, warmup:].squeeze(),
"y_t": y_t[i, warmup:].squeeze(),
"Cgw": self.Cgw_record[i, warmup:],
"satdk": self.satdk_record[i, warmup:],
}
df = pd.DataFrame(data)
df.to_csv(
Expand All @@ -329,8 +337,12 @@ def save_result(self, y_hat, y_t, out_filename, plot_figure=False):
# Plot
eval_metrics = he.evaluator(he.kge, y_hat[i], y_t[i])[0]
fig, axes = plt.subplots(figsize=(5, 5))
axes.plot(y_t[i, warmup:], "-", label="eval (synthetic)", alpha=0.5)
axes.plot(y_hat[i, warmup:], "--", label="sim (recovery)", alpha=0.5)
if self.cfg.run_type == "ML_synthetic_test":
eval_label = "evaluation (synthetic)"
elif self.cfg.run_type == "ML":
eval_label = "observed"
axes.plot(y_t[i, warmup:], "-", label=eval_label, alpha=0.5)
axes.plot(y_hat[i, warmup:], "--", label="predicted", alpha=0.5)
axes.set_title(f"Classic (KGE={float(eval_metrics):.2})")
plt.legend()
plt.savefig(
Expand Down
5 changes: 4 additions & 1 deletion src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ num_processes: 1


################## CHANGE HERE #######################
run_type: ML_synthetic_test # Choose from: "generate_synthetic" or "ML" or "ML_synthetic_test"
run_type: ML # Choose from: "generate_synthetic" or "ML" or "ML_synthetic_test"
# ML: runn ML against observed data
# ML_synthetic_test: runn ML against synthetic data
# generate_synthetic: generate synthetic data for ML_synthetic_test
soil_scheme: classic # Choose from: "ode" or "classic"
######################################################

Expand Down
3 changes: 1 addition & 2 deletions src/data/Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def get_observations(self, cfg: DictConfig):
obs_q_ = pd.read_csv(cfg.data.compare_results_file.format(basin_id))
obs_q_.set_index(pd.to_datetime(obs_q_["date"]), inplace=True)
q = torch.tensor(
obs_q_["QObs(mm/h)"][self.start_time : self.end_time].copy().values
/ cfg.conversions.m_to_mm,
obs_q_["QObs(mm/h)"][self.start_time : self.end_time].copy().values,
device=cfg.device,
)
y_ = torch.stack([q])
Expand Down

0 comments on commit e941e2f

Please sign in to comment.