Update data pipeline to support post batch transforms #939

Open
wants to merge 1 commit into base: main
11 changes: 9 additions & 2 deletions kauldron/data/py/base.py
@@ -40,8 +40,11 @@ class PyGrainPipeline(pipelines.Pipeline):
See doc:

Attributes:
- transforms: A list of transformations to apply to the dataset. Each
- transformation should be either a `grain.MapTransform` or a
+ transforms: A list of transformations to apply to the dataset before
+ batching. Each transformation should be either a `grain.MapTransform` or a
`grain.RandomMapTransform`.
+ post_batch_transforms: A list of transformations to apply after batching.
+ Each transformation should be either a `grain.MapTransform` or a
+ `grain.RandomMapTransform`.
num_epochs: Number of epoch. If missing, iterate indefinitely (number of
iteration is given by `cfg.num_training_steps`)
@@ -62,6 +65,9 @@ class PyGrainPipeline(pipelines.Pipeline):
transforms: tr_normalize.Transformations = dataclasses.field(
default_factory=tuple
)
+ post_batch_transforms: tr_normalize.Transformations = dataclasses.field(
+ default_factory=tuple
+ )

# Params only relevant for the root top-level dataset (when dataset mixture)
num_epochs: Optional[int] = None
@@ -113,6 +119,7 @@ def _root_ds(self) -> grain.IterDataset:
# batching.
if self.batch_size:
ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)
+ ds = transform_utils.apply_transforms(ds, self.post_batch_transforms)

# Distribute the execution across multiple worker processes.
num_workers = _get_num_workers(self.num_workers)
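Note for reviewers: a minimal sketch of how the new field could be wired up. The `NormalizeImage` / `CheckBatchShape` transforms and the `MyPyGrainSource` pipeline subclass are made up for illustration; only `transforms`, `post_batch_transforms`, and `batch_size` are fields touched by this diff.

```python
import grain.python as grain
import numpy as np


class NormalizeImage(grain.MapTransform):
  """Per-example transform: runs before batching, sees a single example."""

  def map(self, features):
    features["image"] = features["image"].astype(np.float32) / 255.0
    return features


class CheckBatchShape(grain.MapTransform):
  """Post-batch transform: runs after batching, sees a leading batch dim."""

  def map(self, batch):
    assert batch["image"].ndim == 4  # (batch, h, w, c)
    return batch


# `MyPyGrainSource` stands in for whatever concrete PyGrainPipeline subclass
# the project uses as a data source.
ds = MyPyGrainSource(
    batch_size=32,
    transforms=[NormalizeImage()],              # applied per example
    post_batch_transforms=[CheckBatchShape()],  # applied to whole batches
)
```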
12 changes: 7 additions & 5 deletions kauldron/data/py/transform_utils.py
@@ -14,12 +14,14 @@

"""Utils for using Kauldron transforms with PyGrain."""

- from typing import Any, Callable, Mapping
+ from typing import Any, Callable, Mapping, TypeVar

import grain.python as grain
from kauldron.data.transforms import abc as tr_abc
from kauldron.data.transforms import normalize as tr_normalize

+ _T = TypeVar("_T", grain.MapDataset, grain.IterDataset)


class PyGrainMapAdapter(tr_normalize.TransformAdapter, grain.MapTransform):
"""Adapter from `kd.data.MapTransform` to pygrain."""
@@ -60,8 +62,8 @@ def _adapt_for_pygrain(


def apply_transforms(
- ds: grain.MapDataset, transforms: tr_normalize.Transformations
- ) -> grain.MapDataset:
+ ds: _T, transforms: tr_normalize.Transformations
+ ) -> _T:
"""Apply the transformations to the dataset."""
if isinstance(transforms, Mapping):
transforms = transforms.values()
@@ -72,8 +74,8 @@


def _apply_transform(
- ds: grain.MapDataset, tr: grain.Transformation
- ) -> grain.MapDataset:
+ ds: _T, tr: grain.Transformation
+ ) -> _T:
"""Apply a list of single transformation."""
match tr:
case grain.MapTransform():
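Note for reviewers: the constrained `_T` TypeVar lets `apply_transforms` / `_apply_transform` accept either dataset flavour while keeping the concrete type for type checkers, which is what allows reuse on the `IterDataset` produced after batching. A small sketch of the pattern (the `add_field` helper is hypothetical; it assumes grain's `MapDataset.range` / `to_iter_dataset` helpers and that `.map` accepts a plain callable):

```python
from typing import TypeVar

import grain.python as grain

_T = TypeVar("_T", grain.MapDataset, grain.IterDataset)


def add_field(ds: _T, value: int) -> _T:
  # Both dataset flavours expose `.map`, and the constrained TypeVar tells
  # the type checker that the output type matches the input type.
  return ds.map(lambda x: {"x": x, "value": value})


map_ds = add_field(grain.MapDataset.range(4), value=1)   # MapDataset in -> MapDataset out
iter_ds = add_field(
    grain.MapDataset.range(4).to_iter_dataset(), value=2
)                                                         # IterDataset in -> IterDataset out
```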
16 changes: 14 additions & 2 deletions kauldron/data/tf/base.py
@@ -58,9 +58,12 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
returns the `tf.data.Dataset` for the current process.

Attributes:
- transforms: A list of `grain.Transformation` to apply to the dataset. Can be
- a dict to allow easier CLI / sweep access (
+ transforms: A list of `grain.Transformation` to apply to the dataset before
+ batching. Can be a dict to allow easier CLI / sweep access (
`--cfg.train_ds.transforms.img_scale.in_vrange=(-1,1)`)
+ post_batch_transforms: A list of `grain.Transformation` to apply to the
+ dataset after batching. Can be a dict to allow easier CLI / sweep access
+ (`--cfg.train_ds.post_batch_transforms.img_scale.in_vrange=(-1,1)`)
tf_data_options: An optional tf.data.Options instance to be applied to the
dataset.
prefetch_size: Number of batches to prefetch for this dataset. Defaults to
@@ -79,6 +82,7 @@ class TFDataPipeline(pipelines.Pipeline, abc.ABC):
# TODO(epot): Users should also be able to specify drop_reminder or mask
batch_drop_remainder: bool = True
transforms: _Transforms = dataclasses.field(default_factory=tuple)
+ post_batch_transforms: _Transforms = dataclasses.field(default_factory=tuple)

# Those fields are only applied once at the top level
tf_data_options: Optional[tf.data.Options] = None
@@ -203,13 +207,21 @@ def _apply_transforms(self, ds: tf.data.Dataset) -> tf.data.Dataset:
transforms.extend(self.transforms.values())
else:
transforms.extend(self.transforms)

+ post_batch_transforms = []
+ if isinstance(self.post_batch_transforms, Mapping):
+ post_batch_transforms.extend(self.post_batch_transforms.values())
+ else:
+ post_batch_transforms.extend(self.post_batch_transforms)

if self.batch_size:
transforms.append(
grain.TfBatch(
batch_size=self.host_batch_size,
drop_remainder=self.batch_drop_remainder,
)
)
+ transforms.extend(post_batch_transforms)
ds = tr_utils.apply_transformations(ds, transforms)
return ds

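Note for reviewers: because the post-batch transforms are appended only after `grain.TfBatch`, they operate on batched tensors. Roughly the equivalent ordering in plain `tf.data` (simplified sketch; `scale_images` and the dummy dataset are made up for illustration):

```python
import tensorflow as tf


def scale_images(batch):
  # Runs on a whole batch: tensors already carry the leading batch dimension.
  batch["image"] = tf.cast(batch["image"], tf.float32) / 255.0
  return batch


ds = tf.data.Dataset.from_tensor_slices(
    {"image": tf.zeros([8, 4, 4, 3], dtype=tf.uint8)}
)
ds = ds.batch(4, drop_remainder=True)  # corresponds to grain.TfBatch(...)
ds = ds.map(scale_images)              # corresponds to post_batch_transforms
```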