diff --git a/configs/dataset/_convert_documents.yaml b/configs/dataset/_convert_documents.yaml
index c87306a5..e3a0b817 100644
--- a/configs/dataset/_convert_documents.yaml
+++ b/configs/dataset/_convert_documents.yaml
@@ -1,2 +1,3 @@
 convert_documents:
-  _processor_: pytorch_ie.DatasetDict.map
+  _processor_: pytorch_ie.DatasetDict.to_document_type
+  document_type: ???
diff --git a/requirements.txt b/requirements.txt
index 718ee51e..cdab9f86 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 # --------- pytorch-ie --------- #
-pytorch-ie>=0.19.0,<1.0.0
+pytorch-ie>=0.23.0,<1.0.0
 # pie-utils provides some useful helper methods for pytorch-ie,
 # e.g. document processors or span utils (convert span annotations
 # to sequence encodings such as BIO, IO or BIOUL, and back).
diff --git a/src/train.py b/src/train.py
index 55363701..4411264c 100644
--- a/src/train.py
+++ b/src/train.py
@@ -41,7 +41,10 @@
 from omegaconf import DictConfig
 from pytorch_ie import DatasetDict
 from pytorch_ie.core import PyTorchIEModel, TaskModule
-from pytorch_ie.models import TransformerTokenClassificationModel
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
+from pytorch_ie.taskmodules import *  # noqa: F403
+from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import Logger

@@ -92,49 +95,61 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     if cfg.get("seed"):
         pl.seed_everything(cfg.seed, workers=True)

-    # Init pytorch-ie dataset
-    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
-    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
-
     # Init pytorch-ie taskmodule
     log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
     taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

+    # Init pytorch-ie dataset
+    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+    dataset: DatasetDict = hydra.utils.instantiate(
+        cfg.dataset,
+        _convert_="partial",
+    )
+
+    # auto-convert the dataset if the taskmodule specifies a document type
+    if taskmodule.document_type is not None:
+        if issubclass(dataset.document_type, taskmodule.document_type):
+            log.info(
+                f"the dataset is already of the document type that is specified by the taskmodule: "
+                f"{taskmodule.document_type}"
+            )
+        else:
+            log.info(
+                f"convert the dataset to the document type that is specified by the taskmodule: "
+                f"{taskmodule.document_type}"
+            )
+            dataset = dataset.to_document_type(taskmodule.document_type)
+    else:
+        log.warning(
+            "The taskmodule does not specify a document type. The dataset can not be automatically converted "
+            "to a document type."
+        )
+
     # Init pytorch-ie datamodule
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
     datamodule: PieDataModule = hydra.utils.instantiate(
         cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
     )

     # Use the train dataset split to prepare the taskmodule
-    taskmodule.prepare(dataset["train"])
+    taskmodule.prepare(dataset[datamodule.train_split])

     # Init the pytorch-ie model
     log.info(f"Instantiating model <{cfg.model._target_}>")
     # get additional model arguments
     additional_model_kwargs: Dict[str, Any] = {}
     model_cls = get_class(cfg.model["_target_"])
-    # NOTE: DEFINE THE additional_model_kwargs IF YOU WANT TO USE ANOTHER MODEL! SEE EXAMPLES BELOW.
-    if model_cls == TransformerTokenClassificationModel:
+    # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
+    # SEE EXAMPLES BELOW.
+    if issubclass(model_cls, RequiresNumClasses):
         additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    # elif model_cls == pytorch_ie.models.TransformerSpanClassificationModel:
-    #     additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    #     max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
-    #     additional_model_kwargs["t_total"] = int(
-    #         max_train_steps / float(cfg["datamodule"]["batch_size"])
-    #     )
-    # elif model_cls == pytorch_ie.models.TransformerTextClassificationModel:
-    #     additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    #     max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
-    #     additional_model_kwargs["t_total"] = int(
-    #         max_train_steps / float(cfg["datamodule"]["batch_size"])
-    #     )
-    # elif model_cls == pytorch_ie.models.TransformerSeq2SeqModel:
-    #     pass
-    else:
-        raise Exception(
-            f"unknown model class: {model_cls.__name__}. Please adjust the train.py script for that class, i.e. "
-            f"define how to set additional_model_kwargs for your model."
-        )
+    if issubclass(model_cls, RequiresModelNameOrPath):
+        if "model_name_or_path" not in cfg.model:
+            raise Exception(
+                f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
+            )
+    if isinstance(taskmodule, ChangesTokenizerVocabSize):
+        additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
+    # initialize the model
     model: PyTorchIEModel = hydra.utils.instantiate(
         cfg.model, _convert_="partial", **additional_model_kwargs
diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py
index 15630f19..7190eb1e 100644
--- a/tests/helpers/package_available.py
+++ b/tests/helpers/package_available.py
@@ -1,7 +1,7 @@
 import platform

 import pkg_resources
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils
+from lightning_fabric.accelerators import TPUAccelerator


 def _package_available(package_name: str) -> bool:
@@ -12,7 +12,7 @@ def _package_available(package_name: str) -> bool:
         return False


-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+_TPU_AVAILABLE = TPUAccelerator.is_available()

 _IS_WINDOWS = platform.system() == "Windows"