From b0ce7afc57fdd714480e6d56557e5bc8272f8e1b Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Tue, 12 Apr 2022 22:12:23 +0200
Subject: [PATCH] Refactor datamodule (#121)

* add DataModule and simple DatasetDict to data package
* simplify Datamodule and add parameters train_split, val_split, and test_split; prepare_data_split defaults to value of train_split
* prepare taskmodule only when stage="fit" or None
* fix typing
* remove DataModule from data package to prevent circular import
* add note regarding cyclic import
* rename prepare_data_split to prepare_split and task_module to taskmodule to be consistent; use DatasetDict
* pass all remaining parameters to the dataloaders
* to the dark side (make black happy)
* fix types
* skip creating encoded splits if respective dataset split is not available
* fix creation of random train val split
* blackify
* create and use PIEDatasetDict instead of DatasetDict
* revert data.__init__ to state of main branch to decrease noise
* use same name for validation split as hf datasets does
* fix exception message
* allow relative split sizes for entries in random_train_val_split
* remove functionality to create a random train val split

Co-authored-by: Arne Binder
---
 src/pytorch_ie/data/datamodules/datamodule.py | 117 ++++++++----------
 1 file changed, 50 insertions(+), 67 deletions(-)

diff --git a/src/pytorch_ie/data/datamodules/datamodule.py b/src/pytorch_ie/data/datamodules/datamodule.py
index c00401c4..ff27723a 100644
--- a/src/pytorch_ie/data/datamodules/datamodule.py
+++ b/src/pytorch_ie/data/datamodules/datamodule.py
@@ -1,10 +1,10 @@
-from typing import Any, Dict, Generic, List, Optional, Tuple
+from typing import Any, Dict, Generic, List, Optional, Tuple, Union
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, random_split
 from torch.utils.data.dataset import Dataset
 
-from pytorch_ie import Document
+from pytorch_ie.data.datasets import PIEDatasetDict
 from pytorch_ie.taskmodules.taskmodule import (
     InputEncoding,
     TargetEncoding,
@@ -20,7 +20,7 @@ class TaskEncodingDataset(
     def __init__(self, encodings: List[TaskEncoding[InputEncoding, TargetEncoding]]):
         self._encodings = encodings
 
-    def __getitem__(self, index) -> TaskEncoding[InputEncoding, TargetEncoding]:
+    def __getitem__(self, index):
         return self._encodings[index]
 
     def __len__(self):
@@ -47,96 +47,79 @@ class DataModule(LightningDataModule, Generic[InputEncoding, TargetEncoding]):
 
     def __init__(
         self,
-        task_module: TaskModule[InputEncoding, TargetEncoding, Any, Any, Any],
-        dataset: Dict[str, List[Document]],
-        random_train_val_split: Optional[Tuple[int, int]] = None,
-        batch_size: int = 32,
-        num_workers: int = 0,
-        pin_memory: bool = False,
+        taskmodule: TaskModule[InputEncoding, TargetEncoding, Any, Any, Any],
+        dataset: PIEDatasetDict,
         data_config_path: Optional[str] = None,
-        prepare_data_split: str = "train",
-        **kwargs,
+        train_split: Optional[str] = "train",
+        val_split: Optional[str] = "validation",
+        test_split: Optional[str] = "test",
+        prepare_split: Optional[str] = None,
+        **dataloader_kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__()
 
-        self.task_module = task_module
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
+        self.taskmodule = taskmodule
         self.config_path = data_config_path
         self.dataset = dataset
-        self.prepare_data_split = prepare_data_split
-        self.random_train_val_split = random_train_val_split
+        self.train_split = train_split
+        self.val_split = val_split
+        self.test_split = test_split
+        # per default, use train data to prepare the taskmodule
+        self.prepare_split = prepare_split or self.train_split
+        self.dataloader_kwargs = dataloader_kwargs
 
-        self.data_train: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
-        self.data_val: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
-        self.data_test: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
+        self._data: Dict[str, TaskEncodingDataset[InputEncoding, TargetEncoding]] = {}
 
     @property
     def num_train(self) -> int:
-        if self.data_train is None:
+        if self.train_split is None:
+            raise ValueError("no train_split assigned")
+        data_train = self._data.get(self.train_split, None)
+        if data_train is None:
             raise ValueError("can not get train size if setup() was not yet called")
-        return len(self.data_train)
+        return len(data_train)
 
     def setup(self, stage: Optional[str] = None, **kwargs):
-        for split, data in self.dataset.items():
-
-            if split == self.prepare_data_split:
-                self.task_module.prepare(data)
-
-            if split == "train":
-                self.data_train = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            elif split == "val":
-                self.data_val = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            elif split == "test":
-                self.data_test = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            else:
-                raise ValueError(
-                    f'Unknowns split identifier: "{split}". Use one of "train", "val", or "test".'
-                )
-
-        if self.random_train_val_split is not None:
-            assert (
-                self.data_train is not None
-            ), "data_train has to be set to create random train dev splits from it"
-            # type checking is broken for random_split, so we ignore it
-            self.data_train, self.data_val = random_split(  # type: ignore
-                self.data_train, self.random_train_val_split
+        if stage == "fit" or stage is None:
+            if self.prepare_split is None:
+                raise ValueError(f"prepare_data_split is required to prepare the taskmodule")
+            self.taskmodule.prepare(self.dataset[self.prepare_split])
+
+        for split in [self.train_split, self.val_split, self.test_split]:
+            if split is None or split not in self.dataset:
+                continue
+            self._data[split] = TaskEncodingDataset(
+                self.taskmodule.encode(self.dataset[split], encode_target=True)
             )
 
+    def data_split(
+        self, split: Optional[str] = None
+    ) -> TaskEncodingDataset[InputEncoding, TargetEncoding]:
+        if split is None or split not in self._data:
+            raise ValueError(f"data for split={split} not available")
+        return self._data[split]
+
     def train_dataloader(self):
         return DataLoader(
-            dataset=self.data_train,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.train_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=True,
+            **self.dataloader_kwargs,
         )
 
     def val_dataloader(self):
         return DataLoader(
-            dataset=self.data_val,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.val_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=False,
+            **self.dataloader_kwargs,
         )
 
     def test_dataloader(self):
         return DataLoader(
-            dataset=self.data_test,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.test_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=False,
+            **self.dataloader_kwargs,
        )
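
For reviewers, a minimal usage sketch of the refactored interface (not part of the patch). The `DataModule` keyword arguments, the split defaults, `setup()`, `num_train`, and the pass-through of extra keyword arguments such as `batch_size` and `num_workers` come directly from the diff above; `load_my_pie_dataset` and `build_my_taskmodule` are hypothetical project-specific helpers introduced only for illustration.

```python
from pytorch_ie.data.datamodules.datamodule import DataModule
from pytorch_ie.data.datasets import PIEDatasetDict

# Assumption: a project-specific helper returning a PIEDatasetDict with
# "train"/"validation"/"test" splits; not a function provided by pytorch-ie.
dataset: PIEDatasetDict = load_my_pie_dataset()

# Assumption: any concrete TaskModule instance works here; how it is
# constructed is up to the project.
taskmodule = build_my_taskmodule()

datamodule = DataModule(
    taskmodule=taskmodule,
    dataset=dataset,
    # train_split/val_split/test_split default to "train"/"validation"/"test";
    # prepare_split falls back to train_split when left as None.
    # All remaining keyword arguments are forwarded to every DataLoader:
    batch_size=32,
    num_workers=2,
)

# stage="fit" (or None) prepares the taskmodule on prepare_split; afterwards
# every split present in the dataset is encoded into a TaskEncodingDataset.
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
print(f"number of encoded train examples: {datamodule.num_train}")
```

Since `**dataloader_kwargs` replaces the former explicit `batch_size`/`num_workers`/`pin_memory` parameters, any argument accepted by `torch.utils.data.DataLoader` can now be passed through unchanged.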