Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update pytorch-ie to 0.23 #125

Merged
merged 4 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/dataset/_convert_documents.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
convert_documents:
_processor_: pytorch_ie.DatasetDict.map
_processor_: pytorch_ie.DatasetDict.to_document_type
document_type: ???
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# --------- pytorch-ie --------- #
pytorch-ie>=0.19.0,<1.0.0
pytorch-ie>=0.23.0,<1.0.0
# pie-utils provides some useful helper methods for pytorch-ie,
# e.g. document processors or span utils (convert span annotations
# to sequence encodings such as BIO, IO or BIOUL, and back).
Expand Down
69 changes: 42 additions & 27 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
from omegaconf import DictConfig
from pytorch_ie import DatasetDict
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import TransformerTokenClassificationModel
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
from pytorch_ie.taskmodules import * # noqa: F403
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import Logger

Expand Down Expand Up @@ -92,49 +95,61 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
if cfg.get("seed"):
pl.seed_everything(cfg.seed, workers=True)

# Init pytorch-ie dataset
log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

# Init pytorch-ie taskmodule
log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

# Init pytorch-ie dataset
log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
dataset: DatasetDict = hydra.utils.instantiate(
cfg.dataset,
_convert_="partial",
)

# auto-convert the dataset if the taskmodule specifies a document type
if taskmodule.document_type is not None:
if issubclass(dataset.document_type, taskmodule.document_type):
log.info(
f"the dataset is already of the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
else:
log.info(
f"convert the dataset to the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
dataset = dataset.to_document_type(taskmodule.document_type)
else:
log.warning(
"The taskmodule does not specify a document type. The dataset can not be automatically converted "
"to a document type."
)

# Init pytorch-ie datamodule
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
datamodule: PieDataModule = hydra.utils.instantiate(
cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
)
# Use the train dataset split to prepare the taskmodule
taskmodule.prepare(dataset["train"])
taskmodule.prepare(dataset[datamodule.train_split])

# Init the pytorch-ie model
log.info(f"Instantiating model <{cfg.model._target_}>")
# get additional model arguments
additional_model_kwargs: Dict[str, Any] = {}
model_cls = get_class(cfg.model["_target_"])
# NOTE: DEFINE THE additional_model_kwargs IF YOU WANT TO USE ANOTHER MODEL! SEE EXAMPLES BELOW.
if model_cls == TransformerTokenClassificationModel:
# NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
# SEE EXAMPLES BELOW.
if issubclass(model_cls, RequiresNumClasses):
additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
# elif model_cls == pytorch_ie.models.TransformerSpanClassificationModel:
# additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
# max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
# additional_model_kwargs["t_total"] = int(
# max_train_steps / float(cfg["datamodule"]["batch_size"])
# )
# elif model_cls == pytorch_ie.models.TransformerTextClassificationModel:
# additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
# max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
# additional_model_kwargs["t_total"] = int(
# max_train_steps / float(cfg["datamodule"]["batch_size"])
# )
# elif model_cls == pytorch_ie.models.TransformerSeq2SeqModel:
# pass
else:
raise Exception(
f"unknown model class: {model_cls.__name__}. Please adjust the train.py script for that class, i.e. "
f"define how to set additional_model_kwargs for your model."
)
if issubclass(model_cls, RequiresModelNameOrPath):
if "model_name_or_path" not in cfg.model:
raise Exception(
f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
)
if isinstance(taskmodule, ChangesTokenizerVocabSize):
additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)

# initialize the model
model: PyTorchIEModel = hydra.utils.instantiate(
cfg.model, _convert_="partial", **additional_model_kwargs
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/package_available.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import platform

import pkg_resources
from pytorch_lightning.utilities.xla_device import XLADeviceUtils
from lightning_fabric.accelerators import TPUAccelerator


def _package_available(package_name: str) -> bool:
Expand All @@ -12,7 +12,7 @@ def _package_available(package_name: str) -> bool:
return False


_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
_TPU_AVAILABLE = TPUAccelerator.is_available()

_IS_WINDOWS = platform.system() == "Windows"

Expand Down