undfined committed Oct 30, 2024
1 parent 19db2a9 commit 1fdb995
Showing 5 changed files with 72 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -12,10 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets.
- Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets.
- Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`.
- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals).

### Changed
- Moved some types into `olmo_core.data.types` to avoid some circular dependencies.
- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals).

### Removed

26 changes: 9 additions & 17 deletions src/olmo_core/data/numpy_dataset.py
@@ -135,7 +135,7 @@ def file_sizes(self) -> Tuple[int, ...]:
The size, in bytes, of each numpy array.
"""
if self._array_file_sizes is None:
self._array_file_sizes = tuple(self.map(lambda item: get_file_size(item[0])))
self._array_file_sizes = tuple(self.map(lambda path, _: get_file_size(path)))
return self._array_file_sizes

@property
@@ -242,7 +242,7 @@ def _warmup_clients(self):

def map(
self,
func: Callable[[Tuple[PathOrStr, int]], T],
func: Callable[[PathOrStr, int], T],
*,
max_workers: Optional[int] = None,
method: Literal["threads", "processes"] = "threads",
@@ -251,7 +251,7 @@
"""
Call a function on each path in the dataset, returning a list of the results, in order.
:param func: The function to map to the paths.
:param func: The function to map to the paths and their indices.
:param max_workers: The number of workers threads/processes. Set to 0 to execute synchronously
in the main thread/process.
:param method: Whether to use multi-threading or multi-processing.
@@ -261,7 +261,7 @@
paths = _paths or self.paths

if max_workers == 0:
return [func((path, idx)) for idx, path in enumerate(paths)]
return [func(path, idx) for idx, path in enumerate(paths)]

executor_class: Union[
Type[concurrent.futures.ThreadPoolExecutor],
@@ -276,7 +276,7 @@
raise ValueError(method)

with executor_class(max_workers=max_workers) as executor:
futures = [executor.submit(func, (path, idx)) for idx, path in enumerate(paths)]
futures = [executor.submit(func, path, idx) for idx, path in enumerate(paths)]

return [future.result() for future in futures]
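With this change, the callable passed to `map` receives the path and its index as two positional arguments rather than a single tuple. A minimal sketch of the updated call pattern, assuming `dataset` is an existing `NumpyDatasetBase` and using the module's `get_file_size` helper:

```python
# Hypothetical usage sketch: func now takes (path, idx) instead of a (path, idx) tuple.
sizes = dataset.map(
    lambda path, idx: get_file_size(path),  # the index is ignored here
    max_workers=0,  # 0 executes synchronously in the main thread
)
```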

@@ -484,9 +484,8 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor:
)

def _get_file_size_and_length(
self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None
self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None
) -> Tuple[int, int]:
path, _ = item
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
file_size = get_file_size(path)
@@ -525,7 +524,6 @@ def __init__(
include_instance_metadata: Optional[bool] = None,
generate_doc_lengths: bool = False,
max_target_sequence_length: Optional[int] = None,
bust_index_cache: bool = False,
):
if max_target_sequence_length is not None and (
max_target_sequence_length < sequence_length
@@ -565,7 +563,6 @@ def __init__(
self._lengths_dtype: Optional[NumpyUIntTypes] = None
self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
self._path_offset_index = path_offset_index
self._bust_index_cache = bust_index_cache
self._seed = seed

def prepare(self):
@@ -629,11 +626,11 @@ def _write_document_indices(self):
)

def _get_file_size_and_length(
self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None
self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None
) -> Tuple[int, int]:
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
file_size = self._get_size_from_offset_index(item)
file_size = self._get_size_from_offset_index((path, idx))
if (
self.max_target_sequence_length is None
or self.max_target_sequence_length == self.sequence_length
@@ -692,7 +689,7 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]:
if self._array_instance_offsets is None:
item_size = self.indices_dtype(0).itemsize
num_instances_per_path = self.map(
lambda item: get_file_size(self._get_instance_indices_path(item[0]))
lambda path, _: get_file_size(self._get_instance_indices_path(path))
// (item_size * 2)
)
array_instance_offsets = []
@@ -1485,10 +1482,6 @@ class NumpyDatasetConfig(Config):
"""
The type of dataset.
"""
bust_index_cache: bool = False
"""
Whether or not to bust the index cache.
"""
source_mixture_config: Optional[SourceMixtureDatasetConfig] = None
"""
The source mixture dataset config.
@@ -1707,7 +1700,6 @@ def build(self) -> NumpyDatasetBase:
include_instance_metadata=self.include_instance_metadata,
generate_doc_lengths=self.generate_doc_lengths,
path_offset_index=mixture.to_index(),
bust_index_cache=self.bust_index_cache,
)
else:
dataset = NumpyFSLDataset(
63 changes: 62 additions & 1 deletion src/olmo_core/data/source_mixture.py
@@ -27,12 +27,31 @@

@dataclass
class SourceMixtureConfig(Config):
"""
A configuration class for building a source mixture.
"""

source_name: str
"""
The name of the source.
"""
target_ratio: float
"""
The target ratio of the source in the mixture.
"""
paths: List[PathOrStr]
# 1.0 will result in a maximum of 1 repetition of the source data per epoch
"""
A list of paths to the source data.
"""
max_repetition_ratio: float = 1.0
"""
The maximum ratio of repetitions of the source data to include in the mixture.
This can be used to upsample the source data by setting the repetition ratio > 1.
"""
max_source_fraction: float = 1.0
"""
The maximum fraction of the source data to include in the mixture.
"""

def validate(self):
if self.target_ratio:
@@ -43,6 +62,9 @@ def validate(self):
if self.max_source_fraction < self.target_ratio:
raise OLMoConfigurationError("max_source_fraction must be >= target_ratio")

if self.max_repetition_ratio < 1:
raise OLMoConfigurationError("max_repetition_ratio must be >= 1")

if not self.paths:
raise OLMoConfigurationError("paths must not be empty")

@@ -57,8 +79,17 @@ class SourceTokenDetails:
"""

config: SourceMixtureConfig
"""
The configuration object associated with the source.
"""
population: int
"""
The total number of tokens available for the source.
"""
num_selected: int
"""
The number of tokens to select for the source.
"""

def for_table(self, max_tokens: int) -> Dict:
return {
Expand All @@ -82,7 +113,13 @@ class SourcePathTokens:
@dataclass
class SourceMixtureOutcome:
name: str
"""
The name of the source.
"""
path_tokens: List[SourcePathTokens]
"""
A list of paths and the associated token counts.
"""


@dataclass
@@ -92,7 +129,13 @@ class SourceMixtureDataset:
"""

seed: int
"""
The seed used to generate the dataset.
"""
sources: List[SourceMixtureOutcome]
"""
A list of sources and the associated paths and token counts.
"""

def to_index(self) -> Dict[Tuple[str, int], int]:
"""
@@ -122,11 +165,29 @@ class SourceMixtureDatasetConfig(Config):
"""

max_tokens: int
"""
The maximum number of tokens to include in the dataset.
"""
source_configs: List[SourceMixtureConfig]
"""
A list of source configurations.
"""
sequence_length: int
"""
The instance sequence length of the dataset.
"""
dtype: NumpyDatasetDType
"""
The data type of the dataset.
"""
processes: int = 1
"""
The number of processes to use for counting tokens in parallel.
"""
seed: int = 42
"""
The seed used to generate the dataset.
"""

def validate(self):
if self.max_tokens <= 0:
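These options tie the per-source configs to a single token budget. A hedged sketch of composing a mixture config, reusing the hypothetical `wiki` source from the earlier sketch; the dtype member is an assumption:

```python
mixture_config = SourceMixtureDatasetConfig(
    max_tokens=10_000_000,           # total token budget across all sources
    source_configs=[wiki],           # SourceMixtureConfig instances
    sequence_length=1024,
    dtype=NumpyDatasetDType.uint16,  # assumed member of NumpyDatasetDType
    processes=2,                     # count source tokens in parallel
    seed=42,
)
mixture_config.validate()
```

A `NumpyDatasetConfig` with `source_mixture_config` set to this object then builds a `NumpyFSLDatasetMixture`, as the fixture and test changes below show.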
1 change: 0 additions & 1 deletion src/test/data/fixtures.py
@@ -59,7 +59,6 @@ def get_fsl_mixture(
source_mixture_config=mixture_config,
sequence_length=sequence_length,
tokenizer=tokenizer,
bust_index_cache=True,
include_instance_metadata=False,
).build()
ds.prepare()
2 changes: 0 additions & 2 deletions src/test/data/numpy_dataset_test.py
@@ -110,7 +110,6 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path):
source_mixture_config=mixture_config,
sequence_length=sequence_length,
tokenizer=tokenizer,
bust_index_cache=True,
include_instance_metadata=False,
).build()
ds.prepare()
@@ -169,7 +168,6 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path):
source_mixture_config=mixture_config,
sequence_length=sequence_length,
tokenizer=tokenizer,
bust_index_cache=True,
include_instance_metadata=False,
).build()
ds.prepare()