Refactor datamodule (#121)
* add DataModule and simple DatasetDict to data package

* simplify DataModule and add parameters train_split, val_split, and test_split; prepare_data_split defaults to the value of train_split

* prepare taskmodule only when stage="fit" or None

* fix typing

* remove DataModule from data package to prevent circular import

* add note regarding cyclic import

* rename prepare_data_split to prepare_split and task_module to taskmodule to be consistent; use DatasetDict

* pass all remaining parameters to the dataloaders (see the usage sketches below)

* to the dark side (make black happy)

* fix types

* skip creating encoded splits if respective dataset split is not available

* fix creation of random train val split

* blackify

* create and use PIEDatasetDict instead of DatasetDict

* revert data.__init__ to state of main branch to decrease noise

* use the same name for the validation split as HF datasets does

* fix exception message

* allow relative split sizes for entries in random_train_val_split

* remove functionality to create a random train val split

Co-authored-by: Arne Binder <[email protected]>
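
Taken together, the renames and the new split parameters reshape the constructor. The following sketch shows a plausible construction call; `my_taskmodule` and `my_dataset` are hypothetical placeholders, and the import path is inferred from the changed file's location below.

    from pytorch_ie.data.datamodules.datamodule import DataModule

    my_taskmodule = ...  # hypothetical placeholder: any TaskModule subclass
    my_dataset = ...  # hypothetical placeholder: a PIEDatasetDict with HF-style splits

    datamodule = DataModule(
        taskmodule=my_taskmodule,
        dataset=my_dataset,
        # split names default to "train", "validation", and "test"; pass None to skip one
        test_split=None,
        # prepare_split defaults to train_split, so the taskmodule is prepared on train data;
        # all remaining keyword arguments are forwarded verbatim to the DataLoader instances
        batch_size=32,
        num_workers=2,
    )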
ArneBinder authored Apr 12, 2022
1 parent ed634ac · commit b0ce7af
Showing 1 changed file with 50 additions and 67 deletions.
src/pytorch_ie/data/datamodules/datamodule.py (50 additions, 67 deletions)
@@ -1,10 +1,10 @@
-from typing import Any, Dict, Generic, List, Optional, Tuple
+from typing import Any, Dict, Generic, List, Optional, Tuple, Union

 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, random_split
 from torch.utils.data.dataset import Dataset

-from pytorch_ie import Document
+from pytorch_ie.data.datasets import PIEDatasetDict
 from pytorch_ie.taskmodules.taskmodule import (
     InputEncoding,
     TargetEncoding,
@@ -20,7 +20,7 @@ class TaskEncodingDataset(
     def __init__(self, encodings: List[TaskEncoding[InputEncoding, TargetEncoding]]):
         self._encodings = encodings

-    def __getitem__(self, index) -> TaskEncoding[InputEncoding, TargetEncoding]:
+    def __getitem__(self, index):
         return self._encodings[index]

     def __len__(self):
@@ -47,96 +47,79 @@ class DataModule(LightningDataModule, Generic[InputEncoding, TargetEncoding]):

     def __init__(
         self,
-        task_module: TaskModule[InputEncoding, TargetEncoding, Any, Any, Any],
-        dataset: Dict[str, List[Document]],
-        random_train_val_split: Optional[Tuple[int, int]] = None,
-        batch_size: int = 32,
-        num_workers: int = 0,
-        pin_memory: bool = False,
+        taskmodule: TaskModule[InputEncoding, TargetEncoding, Any, Any, Any],
+        dataset: PIEDatasetDict,
         data_config_path: Optional[str] = None,
-        prepare_data_split: str = "train",
-        **kwargs,
+        train_split: Optional[str] = "train",
+        val_split: Optional[str] = "validation",
+        test_split: Optional[str] = "test",
+        prepare_split: Optional[str] = None,
+        **dataloader_kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__()

-        self.task_module = task_module
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
+        self.taskmodule = taskmodule
         self.config_path = data_config_path
         self.dataset = dataset
-        self.prepare_data_split = prepare_data_split
-        self.random_train_val_split = random_train_val_split
+        self.train_split = train_split
+        self.val_split = val_split
+        self.test_split = test_split
+        # per default, use train data to prepare the taskmodule
+        self.prepare_split = prepare_split or self.train_split
+        self.dataloader_kwargs = dataloader_kwargs

-        self.data_train: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
-        self.data_val: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
-        self.data_test: Optional[TaskEncodingDataset[InputEncoding, TargetEncoding]] = None
+        self._data: Dict[str, TaskEncodingDataset[InputEncoding, TargetEncoding]] = {}

     @property
     def num_train(self) -> int:
-        if self.data_train is None:
+        if self.train_split is None:
+            raise ValueError("no train_split assigned")
+        data_train = self._data.get(self.train_split, None)
+        if data_train is None:
             raise ValueError("can not get train size if setup() was not yet called")
-        return len(self.data_train)
+        return len(data_train)

     def setup(self, stage: Optional[str] = None, **kwargs):

-        for split, data in self.dataset.items():
-
-            if split == self.prepare_data_split:
-                self.task_module.prepare(data)
-
-            if split == "train":
-                self.data_train = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            elif split == "val":
-                self.data_val = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            elif split == "test":
-                self.data_test = TaskEncodingDataset(
-                    self.task_module.encode(data, encode_target=True)
-                )
-            else:
-                raise ValueError(
-                    f'Unknowns split identifier: "{split}". Use one of "train", "val", or "test".'
-                )
-
-        if self.random_train_val_split is not None:
-            assert (
-                self.data_train is not None
-            ), "data_train has to be set to create random train dev splits from it"
-            # type checking is broken for random_split, so we ignore it
-            self.data_train, self.data_val = random_split(  # type: ignore
-                self.data_train, self.random_train_val_split
+        if stage == "fit" or stage is None:
+            if self.prepare_split is None:
+                raise ValueError(f"prepare_data_split is required to prepare the taskmodule")
+            self.taskmodule.prepare(self.dataset[self.prepare_split])
+
+        for split in [self.train_split, self.val_split, self.test_split]:
+            if split is None or split not in self.dataset:
+                continue
+            self._data[split] = TaskEncodingDataset(
+                self.taskmodule.encode(self.dataset[split], encode_target=True)
             )

+    def data_split(
+        self, split: Optional[str] = None
+    ) -> TaskEncodingDataset[InputEncoding, TargetEncoding]:
+        if split is None or split not in self._data:
+            raise ValueError(f"data for split={split} not available")
+        return self._data[split]
+
     def train_dataloader(self):
         return DataLoader(
-            dataset=self.data_train,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.train_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=True,
+            **self.dataloader_kwargs,
         )

     def val_dataloader(self):
         return DataLoader(
-            dataset=self.data_val,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.val_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=False,
+            **self.dataloader_kwargs,
        )

     def test_dataloader(self):
         return DataLoader(
-            dataset=self.data_test,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=self.task_module.collate,
+            dataset=self.data_split(self.test_split),
+            collate_fn=self.taskmodule.collate,
             shuffle=False,
+            **self.dataloader_kwargs,
         )
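
A hedged sketch of how the refactored module would be driven from a training script, continuing from the hypothetical `datamodule` constructed after the commit message; it assumes the dataset actually contains the configured splits.

    # stage="fit" (or None) first prepares the taskmodule on prepare_split;
    # afterwards, every configured split found in the dataset is encoded
    datamodule.setup(stage="fit")

    # num_train raises a ValueError if train_split is None or setup() has not run yet
    print(datamodule.num_train)

    # the dataloaders resolve their dataset via data_split(), which raises for
    # skipped or missing splits; constructor kwargs (batch_size, num_workers) apply here
    train_loader = datamodule.train_dataloader()  # shuffle=True, collate_fn=taskmodule.collate
    val_loader = datamodule.val_dataloader()  # shuffle=False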
