diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 95f37edfe..64846bd91 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -16,7 +16,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.10"]
+        python-version: ["3.8", "3.9", "3.10"]
         pytorch-version: ["2.0"]
 
     runs-on: "ubuntu-latest"
diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py
index 8fd6263b5..10f918621 100644
--- a/graphium/finetuning/finetuning_architecture.py
+++ b/graphium/finetuning/finetuning_architecture.py
@@ -1,25 +1,13 @@
-from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type
-
-from copy import deepcopy
-
-from loguru import logger
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.nn as nn
-
 from torch import Tensor
 from torch_geometric.data import Batch
 
-from graphium.data.utils import get_keys
-from graphium.nn.base_graph_layer import BaseGraphStructure
-from graphium.nn.architectures.encoder_manager import EncoderManager
-from graphium.nn.architectures import FullGraphMultiTaskNetwork, FeedForwardNN, FeedForwardPyg, TaskHeads
-from graphium.nn.architectures.global_architectures import FeedForwardGraph
-from graphium.trainer.predictor_options import ModelOptions
 from graphium.nn.utils import MupMixin
-
 from graphium.trainer.predictor import PredictorModule
-from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT, FINETUNING_HEADS_DICT
+from graphium.utils.spaces import FINETUNING_HEADS_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT
 
 
 class FullGraphFinetuningNetwork(nn.Module, MupMixin):
@@ -318,7 +306,7 @@ def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
         self.net = net(**finetuning_head_kwargs)
 
     def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]):
-        if isinstance(g, Union[torch.Tensor, Batch]):
+        if isinstance(g, (torch.Tensor, Batch)):
             pass
         elif isinstance(g, Dict) and len(g) == 1:
             g = list(g.values())[0]
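
Note on the isinstance change (illustration, not part of the patch): isinstance() only
accepts typing.Union subscripts from Python 3.10 onward; on the 3.8 and 3.9 interpreters
added to the CI matrix above, it raises "TypeError: Subscripted generics cannot be used
with class and instance checks". A plain tuple of types is valid on every version, which
is why the forward() check is rewritten. Minimal sketch of the behavior, with int standing
in for Batch so only torch is required:

from typing import Union

import torch

g = torch.zeros(3)

# Pre-patch form: returns True on Python 3.10+, but raises TypeError on 3.8/3.9.
try:
    print(isinstance(g, Union[torch.Tensor, int]))
except TypeError as err:
    print(f"rejected on this interpreter: {err}")

# Patched form: a tuple of types, accepted on all supported Python versions.
print(isinstance(g, (torch.Tensor, int)))  # -> True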