Merge branch 'main' into elastic-resumption-regtest
irenedea authored Sep 18, 2023
2 parents 27dca35 + c369a68 commit 5445435
Showing 20 changed files with 1,674 additions and 47 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
@@ -46,7 +46,7 @@ def __init__(
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'fp32',
precision: str = 'float32',
overwrite: bool = False,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
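For context, this change aligns the checkpointer's default precision string with the canonical torch dtype names ('float32' rather than the 'fp32' shorthand). A minimal usage sketch under that change; the folder and interval values below are illustrative, not taken from the commit:

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

# 'float32' is the new default spelling; the old 'fp32' shorthand no longer matches.
checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/hf-checkpoints',  # illustrative path
    save_interval='1000ba',                       # illustrative interval (Composer Time string)
    precision='float32',
)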
2 changes: 2 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -136,6 +136,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
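Both new keys are read with .get() fallbacks, so existing configs keep working unchanged. A hedged sketch of the dataset sub-config with the new knobs set explicitly; the paths are placeholders and only the keys relevant to this change are shown:

from omegaconf import OmegaConf

dataset_cfg = OmegaConf.create({
    'remote': 's3://my-bucket/finetune-data',  # placeholder
    'local': '/tmp/finetune-data',             # placeholder
    'split': 'train',
    'sampling_method': 'balanced',
    'sampling_granularity': 1,    # pick samples one at a time, evenly across shards
    'batching_method': 'random',  # or 'stratified' / 'per_stream'
})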
104 changes: 73 additions & 31 deletions llmfoundry/data/finetuning/tasks.py
@@ -71,44 +71,76 @@ class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Args:
local (str): Local dataset directory where shards are cached by split.
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
remote (str, optional): Download shards from this remote path or directory. If None, this
rank and worker's partition of the dataset must all exist locally. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
keep_zip (bool, optional): Whether to keep or delete the compressed file when
decompressing downloaded shards. If set to None, keep if remote is local. Defaults to
``None``.
local (str): Local dataset directory where shards are cached by split.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
If ``None``, defaults to the number of nodes of the initial run. Defaults to 128.
        keep_zip (bool): Whether to keep or delete the compressed form when decompressing
            downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
            ``False``.
epoch_size (int, optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
        cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s) may
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
to disable shard eviction. Supports integer bytes as well as string human-readable
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""

def __init__(self,
local: str,
tokenizer: PreTrainedTokenizerBase,
local: str,
remote: Optional[str] = None,
split: Optional[str] = None,
shuffle: bool = False,
predownload: Optional[int] = 100_000,
keep_zip: bool = False,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
shuffle_seed: int = 9176,
num_canonical_nodes: Optional[int] = 128,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
partition_algo: str = 'orig',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1b',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
**kwargs: Any):

if len(kwargs) > 0:
@@ -125,18 +157,28 @@ def __init__(self,
)

# Build Dataset
super().__init__(local=local,
remote=remote,
split=split,
shuffle=shuffle,
predownload=predownload,
keep_zip=keep_zip,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
shuffle_seed=shuffle_seed,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size)
super().__init__(
local=local,
remote=remote,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
shuffle=shuffle,
shuffle_algo=shuffle_algo,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
)

self.tokenizer = tokenizer

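With the wrapper now forwarding the full StreamingDataset argument surface, construction looks like this. A minimal sketch assuming a prepared MDS dataset; the tokenizer choice and paths are placeholders:

from transformers import AutoTokenizer
from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    local='/tmp/finetune-cache',            # placeholder
    remote='s3://my-bucket/finetune-data',  # placeholder
    split='train',
    shuffle=True,
    sampling_granularity=1,
    batching_method='random',
)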
13 changes: 12 additions & 1 deletion llmfoundry/data/text_data.py
@@ -66,7 +66,14 @@ class StreamingTextDataset(StreamingDataset):
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""

def __init__(self,
@@ -91,6 +98,8 @@ def __init__(self,
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
**kwargs: Any):

group_method = kwargs.pop('group_method', None)
@@ -138,6 +147,8 @@ def __init__(self,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
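StreamingTextDataset gains the same two pass-through arguments. A hedged construction sketch; the paths are placeholders and max_seq_len should match your model config:

from transformers import AutoTokenizer
from llmfoundry.data.text_data import StreamingTextDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
dataset = StreamingTextDataset(
    tokenizer=tokenizer,
    max_seq_len=2048,
    local='/tmp/text-cache',            # placeholder
    remote='s3://my-bucket/text-data',  # placeholder
    split='train',
    sampling_granularity=1000,    # take 1000 samples per shard on a stream's final partial repeat
    batching_method='per_stream',
)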
14 changes: 14 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.inference_api_wrapper.interface import \
InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)

__all__ = [
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
'OpenAITokenizerWrapper',
'InferenceAPIEvalWrapper',
]
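The new subpackage exposes one abstract interface plus three OpenAI wrappers; the causal-LM and chat wrappers subclass InferenceAPIEvalWrapper, defined in the interface module shown next. An illustrative import check:

from llmfoundry.models.inference_api_wrapper import (
    InferenceAPIEvalWrapper, OpenAICausalLMEvalWrapper,
    OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)

assert issubclass(OpenAICausalLMEvalWrapper, InferenceAPIEvalWrapper)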
110 changes: 110 additions & 0 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -0,0 +1,110 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.metrics.nlp import (InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer


class InferenceAPIEvalWrapper(ComposerModel):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
self.tokenizer = tokenizer
self.labels = None
# set up training and eval metrics
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
self.eval_metrics = {
metric.__class__.__name__: metric for metric in eval_metrics
}
super().__init__()

def get_metrics(self, is_train: bool = False):
if is_train:
raise NotImplementedError(
'You cannot use inference wrappers for training')
else:
metrics = self.eval_metrics

return metrics if metrics else {}

def get_next_token_logit_tensor(self,
prompt: str) -> Optional[torch.Tensor]:
raise NotImplementedError

def rebatch(self, batch: Batch):
# default is a no-op, but Chat API modifies these
return batch

def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
# If the batch mode is generate, we will generate a requested number of tokens using the underlying
# model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
# be returned from eval_forward
output_logits_batch = []
for tokens, cont_idxs in zip(batch['input_ids'],
batch['continuation_indices']):

seqlen = tokens.shape[0]
tokens = tokens.tolist()
cont_idxs = cont_idxs.tolist()
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
output_logits = torch.nn.functional.one_hot(
torch.tensor(tokens[1:cont_idxs[0]]),
num_classes=self.tokenizer.vocab_size)
for i in range(len(expected_cont_tokens)):
# decode one token at a time
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
expected_cont_tokens[0:i])
next_logit_tensor = self.get_next_token_logit_tensor(prompt)
if next_logit_tensor is None:
continue
output_logits = torch.cat(
[output_logits,
next_logit_tensor.reshape(1, -1)])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
batch = self.rebatch(batch)
self.labels = batch.pop('labels')
self.labels[:, :-1] = self.labels[:, 1:].clone()
self.labels[:, -1] = -100
if isinstance(metric, InContextLearningMetric) and batch.get(
'mode', None) == 'icl_task':
assert self.labels is not None
metric.update(batch, outputs, self.labels)
else:
raise NotImplementedError(
'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
)

def forward(self):
raise NotImplementedError(
"Inference API wrapper doesn't support forward")

def loss(self):
raise NotImplementedError("Inference API wrapper doesn't support loss")
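A concrete wrapper only has to implement get_next_token_logit_tensor: given the decoded prompt so far, return a vocab-sized logit tensor for the next token (or None to skip it). eval_forward then assembles, per sample, one-hot rows for the prompt prefix, the API-derived rows for each continuation token, and one-hot pad rows up to the full sequence length, yielding a [seq_len, vocab_size] tensor. A minimal sketch of a subclass; DummyAPIEvalWrapper and its uniform-logit behavior are invented for illustration, not part of the commit:

import torch
from transformers import AutoTokenizer
from llmfoundry.models.inference_api_wrapper import InferenceAPIEvalWrapper

class DummyAPIEvalWrapper(InferenceAPIEvalWrapper):
    """Illustrative stub: returns uniform logits instead of calling a real API."""

    def get_next_token_logit_tensor(self, prompt: str):
        # A real implementation would query its inference API with `prompt`
        # and map the returned token (or logprobs) onto the tokenizer vocab.
        return torch.zeros(self.tokenizer.vocab_size)

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
wrapper = DummyAPIEvalWrapper(model_cfg={}, tokenizer=tokenizer)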