diff --git a/kauldron/data/py/base.py b/kauldron/data/py/base.py
index 6bd30a99..db817237 100644
--- a/kauldron/data/py/base.py
+++ b/kauldron/data/py/base.py
@@ -40,8 +40,11 @@ class PyGrainPipeline(pipelines.Pipeline):
   See doc:
 
   Attributes:
-    transforms: A list of transformations to apply to the dataset. Each
-      transformation should be either a `grain.MapTransform` or a
+    transforms: A list of transformations to apply to the dataset before
+      batching. Each transformation should be either a `grain.MapTransform` or a
+      `grain.RandomMapTransform`.
+    post_batch_transforms: A list of transformations to apply after batching.
+      Each transformation should be either a `grain.MapTransform` or a
       `grain.RandomMapTransform`.
     num_epochs: Number of epoch. If missing, iterate indefinitely (number of
       iteration is given by `cfg.num_training_steps`)
@@ -62,6 +65,9 @@ class PyGrainPipeline(pipelines.Pipeline):
   transforms: tr_normalize.Transformations = dataclasses.field(
       default_factory=tuple
   )
+  post_batch_transforms: tr_normalize.Transformations = dataclasses.field(
+      default_factory=tuple
+  )
 
   # Params only relevant for the root top-level dataset (when dataset mixture)
   num_epochs: Optional[int] = None
@@ -113,6 +119,7 @@ def _root_ds(self) -> grain.IterDataset:
     # batching.
     if self.batch_size:
       ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)
+      ds = transform_utils.apply_transforms(ds, self.post_batch_transforms)
 
     # Distribute the execution across multiple worker processes.
     num_workers = _get_num_workers(self.num_workers)
diff --git a/kauldron/data/py/transform_utils.py b/kauldron/data/py/transform_utils.py
index 19548d3e..9712a535 100644
--- a/kauldron/data/py/transform_utils.py
+++ b/kauldron/data/py/transform_utils.py
@@ -14,12 +14,14 @@
 """Utils for using Kauldron transforms with PyGrain."""
 
-from typing import Any, Callable, Mapping
+from typing import Any, Callable, Mapping, TypeVar
 
 import grain.python as grain
 from kauldron.data.transforms import abc as tr_abc
 from kauldron.data.transforms import normalize as tr_normalize
 
+_T = TypeVar("_T", grain.MapDataset, grain.IterDataset)
+
 
 class PyGrainMapAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
   """Adapter from `kd.data.MapTransform` to pygrain."""
 
@@ -60,8 +62,8 @@ def _adapt_for_pygrain(
 
 
 def apply_transforms(
-    ds: grain.MapDataset, transforms: tr_normalize.Transformations
-) -> grain.MapDataset:
+    ds: _T, transforms: tr_normalize.Transformations
+) -> _T:
   """Apply the transformations to the dataset."""
   if isinstance(transforms, Mapping):
     transforms = transforms.values()
@@ -72,8 +74,8 @@ def apply_transforms(
 
 
 def _apply_transform(
-    ds: grain.MapDataset, tr: grain.Transformation
-) -> grain.MapDataset:
+    ds: _T, tr: grain.Transformation
+) -> _T:
   """Apply a list of single transformation."""
   match tr:
     case grain.MapTransform():
diff --git a/kauldron/data/tf/base.py b/kauldron/data/tf/base.py
index dee9f017..731f7c11 100644
--- a/kauldron/data/tf/base.py
+++ b/kauldron/data/tf/base.py
@@ -58,9 +58,12 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
   returns the `tf.data.Dataset` for the current process.
 
   Attributes:
-    transforms: A list of `grain.Transformation` to apply to the dataset. Can be
-      a dict to allow easier CLI / sweep access (
+    transforms: A list of `grain.Transformation` to apply to the dataset before
+      batching. Can be a dict to allow easier CLI / sweep access (
       `--cfg.train_ds.transforms.img_scale.in_vrange=(-1,1)`)
+    post_batch_transforms: A list of `grain.Transformation` to apply to the
+      dataset after batching. Can be a dict to allow easier CLI / sweep access
+      (`--cfg.train_ds.post_batch_transforms.img_scale.in_vrange=(-1,1)`)
     tf_data_options: An optional tf.data.Options instance to be applied to the
       dataset.
     prefetch_size: Number of batches to prefetch for this dataset. Defaults to
@@ -79,6 +82,7 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
   # TODO(epot): Users should also be able to specify drop_reminder or mask
   batch_drop_remainder: bool = True
   transforms: _Transforms = dataclasses.field(default_factory=tuple)
+  post_batch_transforms: _Transforms = dataclasses.field(default_factory=tuple)
 
   # Those fields are only applied once at the top level
   tf_data_options: Optional[tf.data.Options] = None
@@ -203,6 +207,13 @@ def _apply_transforms(self, ds: tf.data.Dataset) -> tf.data.Dataset:
       transforms.extend(self.transforms.values())
     else:
       transforms.extend(self.transforms)
+
+    post_batch_transforms = []
+    if isinstance(self.post_batch_transforms, Mapping):
+      post_batch_transforms.extend(self.post_batch_transforms.values())
+    else:
+      post_batch_transforms.extend(self.post_batch_transforms)
+
     if self.batch_size:
       transforms.append(
           grain.TfBatch(
@@ -210,6 +221,7 @@ def _apply_transforms(self, ds: tf.data.Dataset) -> tf.data.Dataset:
               drop_remainder=self.batch_drop_remainder,
           )
       )
+      transforms.extend(post_batch_transforms)
     ds = tr_utils.apply_transformations(ds, transforms)
     return ds
 
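
Usage sketch (not part of the diff): a minimal example of how the new `post_batch_transforms` field might be configured on a pipeline. The `Tfds` source and the `Elements` / `ValueRange` transforms are assumed Kauldron names used purely for illustration; only the `post_batch_transforms` field itself is introduced by this change.

# Hypothetical config sketch: per-example transforms go in `transforms`,
# per-batch transforms (which see a leading batch dimension) go in the new
# `post_batch_transforms`. `Tfds`, `Elements`, and `ValueRange` are assumed
# names for illustration only.
from kauldron import kd

train_ds = kd.data.py.Tfds(
    name="mnist",
    split="train",
    shuffle=True,
    batch_size=256,
    transforms=[
        # Runs on individual examples, before batching.
        kd.data.Elements(keep=["image", "label"]),
    ],
    post_batch_transforms=[
        # Runs on full batches, after batching.
        kd.data.ValueRange(key="image", vrange=(0, 1)),
    ],
)

The `TypeVar` change in `transform_utils.apply_transforms` is what enables the PyGrain half of this: constraining `_T` to `grain.MapDataset` or `grain.IterDataset` lets the same helper be reused after batching, where the dataset may already be a `grain.IterDataset`, while preserving the concrete dataset type for type checkers.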