Skip to content

Commit

Permalink
Refactor gnn training (#249)
Browse files Browse the repository at this point in the history
* minor upds

* refactor: training pipeline

* feat: find gcs image path

* feat: feature generation in trainer

* feat: validation sets in training

* bug: hgraph forward passes with missing edge types

* refactor: hgnn trainer

* feat: functional training pipeline

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 27, 2024
1 parent b081e60 commit c86bdc6
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
WEIGHT_DECAY = 1e-3


class HeteroGraphTrainer:
class Trainer:
"""
Custom class that trains graph neural networks.
Expand Down Expand Up @@ -107,9 +107,9 @@ def run(self, train_dataset_list, validation_dataset_list):
# Train
y, hat_y = [], []
self.model.train()
for graph_dataset in train_dataset_list:
for dataset in train_dataset_list:
# Forward pass
hat_y_i, y_i = self.predict(graph_dataset.data)
hat_y_i, y_i = self.predict(dataset.data)
loss = self.criterion(hat_y_i, y_i)
self.writer.add_scalar("loss", loss, epoch)

Expand All @@ -129,8 +129,8 @@ def run(self, train_dataset_list, validation_dataset_list):
if epoch % 10 == 0:
y, hat_y = [], []
self.model.eval()
for graph_dataset in validation_dataset_list:
hat_y_i, y_i = self.predict(graph_dataset.data)
for dataset in validation_dataset_list:
hat_y_i, y_i = self.predict(dataset.data)
y.extend(toCPU(y_i))
hat_y.extend(toCPU(hat_y_i))
test_score = self.compute_metrics(y, hat_y, "val", epoch)
Expand Down
Loading

0 comments on commit c86bdc6

Please sign in to comment.