Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data mixes #29

Merged
merged 2 commits into the base branch from the feature branch
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added support for unsharding model state into `safetensors` format with `olmo_core.distributed.checkpoint.unshard_checkpoint(..., use_safetensors=True)`.
- Added `data.TokenizerConfig` config class and `data.TokenizerNames` enumeration.
- Added `data.TokenizerConfig` config class and `data.TokenizerName` enumeration.
- Added data mixes with `data.DataMix` API.

## [v1.0.1](https://github.com/allenai/OLMo-core/releases/tag/v1.0.1) - 2024-08-26

Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include src/olmo_core/data/mixes/*.txt
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"rich",
"omegaconf",
"safetensors",
"importlib_resources",
]

[project.urls]
Expand Down Expand Up @@ -65,7 +66,7 @@ all = [
include-package-data = true

[tool.setuptools.package-data]
olmo_core = ["py.typed"]
olmo_core = ["py.typed", "*.txt"]

[tool.setuptools.dynamic]
version = { attr = "olmo_core.version.VERSION" }
Expand Down
6 changes: 4 additions & 2 deletions src/olmo_core/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from .collator import DataCollator, PaddingDirection
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset, MemMapDatasetConfig, MemMapDType
from .tokenizer import TokenizerConfig, TokenizerNames
from .mixes import DataMix
from .tokenizer import TokenizerConfig, TokenizerName

__all__ = [
"MemMapDatasetConfig",
"MemMapDataset",
"MemMapDType",
"TokenizerConfig",
"TokenizerNames",
"TokenizerName",
"DataMix",
"DataCollator",
"PaddingDirection",
"IterableDataset",
Expand Down
49 changes: 43 additions & 6 deletions src/olmo_core/data/memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import torch
from torch.utils.data import Dataset

from olmo_core.exceptions import OLMoEnvironmentError
from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError

from ..aliases import PathOrStr
from ..config import Config, StrEnum
from ..io import _get_s3_client, file_size, get_bytes_range
from ..utils import get_document_lengths
from .mixes import DataMix
from .tokenizer import TokenizerConfig

__all__ = ["MemMapDatasetConfig", "MemMapDataset"]
Expand All @@ -41,9 +42,10 @@ class MemMapDatasetConfig(Config):
A config class for easily building :class:`MemMapDataset` classes.
"""

paths: List[str]
sequence_length: int
tokenizer: TokenizerConfig
paths: Optional[List[str]] = None
mix: Optional[DataMix] = None
memmap_dtype: Optional[MemMapDType] = None
metadata: Optional[List[Dict[str, Any]]] = None
include_instance_metadata: bool = True
Expand All @@ -62,10 +64,28 @@ def glob(cls, *glob_paths: str, **kwargs) -> "MemMapDatasetConfig":
If any of the globs don't expand to any matches a :class:`FileNotFoundError`
error is raised

:returns: A new config.
:returns: A new dataset config.
"""
return cls(paths=list(glob_paths), expand_glob=True, **kwargs)

@classmethod
def from_data_mix(
    cls, mix: DataMix, *, tokenizer: TokenizerConfig, **kwargs
) -> "MemMapDatasetConfig":
    """
    Build a dataset config from one of the official data mixes.

    :param mix: The data mix to pull data paths from.
    :param tokenizer: The tokenizer config. Its ``identifier`` must be set,
        since mixes are resolved per-tokenizer.

    :returns: A new dataset config.
    :raises OLMoConfigurationError: If the tokenizer config has no identifier.
    """
    # Success path first: a tokenizer identifier is the only precondition.
    if tokenizer.identifier is not None:
        return cls(mix=mix, tokenizer=tokenizer, **kwargs)
    raise OLMoConfigurationError(
        "Missing tokenizer identifier required to construct data mix"
    )

def get_memmap_dtype(
self,
) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]:
Expand All @@ -85,12 +105,18 @@ def get_memmap_dtype(

raise ValueError("vocab size too big!")

def build(self) -> MemMapDataset:
def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset:
"""
Construct the corresponding :class:`MemMapDataset`.

:param mix_base_dir: The base directory for the :data:`mix`, e.g. "s3://ai2-llm".
Required if initializing from a data mix.
"""
if (self.paths is None) == (self.mix is None):
raise OLMoConfigurationError("Exactly one of 'paths' or 'mix' is required")

paths: List[str] = []
if self.expand_glob:
if self.paths and self.expand_glob:
from glob import glob

for glob_path in self.paths:
Expand All @@ -101,8 +127,19 @@ def build(self) -> MemMapDataset:
for path in matches:
log.info(f" - '{path}'")
paths.extend(matches)
else:
elif self.paths:
paths = self.paths
else:
assert self.mix is not None
if mix_base_dir is None:
raise OLMoConfigurationError(
"'mix_base_dir' is required to build a dataset from a mix"
)
if self.tokenizer.identifier is None:
raise OLMoConfigurationError(
"Missing tokenizer identifier required to construct data mix"
)
paths = self.mix.build(mix_base_dir, self.tokenizer.identifier)

dataset = MemMapDataset(
*paths,
Expand Down
Loading
Loading