
Commit

Merge branch 'main' into convert_mds_script
irenedea authored Sep 13, 2023
2 parents 3d7f9c9 + 0fdf43f commit 5b8cee5
Showing 35 changed files with 326 additions and 368 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/regressions.py
@@ -13,7 +13,8 @@
 
 
 def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
-                wandb_project: str):
+                wandb_project: str, git_repo: str, git_branch: str):
+    print(f'Running regression tests on {git_repo} {git_branch}.')
     eval_7b_hf = RunConfig.from_file(
         os.path.join(REGRESSIONS_DIR, 'eval-7b-hf.yaml'))
     eval_7b_composer = RunConfig.from_file(
@@ -48,6 +49,8 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
         config.cluster = cluster
         config.parameters['loggers'] = config.parameters.get('loggers', {})
         config.parameters['loggers']['wandb'] = wandb_config
+        config.integrations[0]['git_repo'] = git_repo
+        config.integrations[0]['git_branch'] = git_branch
 
     return all_configs, []

@@ -58,10 +61,13 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str,
     parser.add_argument('--mpt-7b-ckpt-path', type=str)
     parser.add_argument('--wandb-entity', type=str)
    parser.add_argument('--wandb-project', type=str)
+    parser.add_argument('--git-repo', type=str, default='mosaicml/llm-foundry')
+    parser.add_argument('--git-branch', type=str, default='main')
 
     args = parser.parse_args()
 
     run_configs, _ = get_configs(args.cluster, args.mpt_7b_ckpt_path,
-                                 args.wandb_entity, args.wandb_project)
+                                 args.wandb_entity, args.wandb_project,
+                                 args.git_repo, args.git_branch)
     for run_config in run_configs:
         run = create_run(run_config)
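For context, a self-contained sketch of the extended CLI surface; it mirrors the argparse lines above, and the parsed example values are assumptions rather than part of the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cluster', type=str)
parser.add_argument('--mpt-7b-ckpt-path', type=str)
parser.add_argument('--wandb-entity', type=str)
parser.add_argument('--wandb-project', type=str)
parser.add_argument('--git-repo', type=str, default='mosaicml/llm-foundry')
parser.add_argument('--git-branch', type=str, default='main')

# Hypothetical invocation: only --cluster is given, so the new git flags
# fall back to their defaults.
args = parser.parse_args(['--cluster', 'my-cluster'])
print(args.git_repo, args.git_branch)  # -> mosaicml/llm-foundry main
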
12 changes: 7 additions & 5 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -3,6 +3,7 @@
 
 """Aggregate ICL evals into composite scores."""
 
+import logging
 import math
 from enum import Enum
 from typing import Optional
@@ -12,6 +13,8 @@
 
 __all__ = ['EvalGauntlet']
 
+log = logging.getLogger(__name__)
+
 
 class Weighting(Enum):
     EQUAL = 1
@@ -130,9 +133,8 @@ def eval_after_all(self, state: State, logger: Logger):
                 key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
 
                 if key not in new_metrics:
-                    print(
-                        f"Warning: couldn't find results for benchmark: {benchmark}"
-                    )
+                    log.warning(
+                        f'Could not find results for benchmark: {benchmark}.')
                     missing_metrics.append(key)
                 else:
                     score = new_metrics[key]
@@ -150,8 +152,8 @@ def eval_after_all(self, state: State, logger: Logger):
                 })
 
             if len(missing_metrics) > 0:
-                print(
-                    f"Removing category `{category['name']}` from gauntlet scores because benchmarks were missing: {missing_metrics}"
+                log.warning(
+                    f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
                 )
                 del composite_scores[category['name']]
                 continue
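The print-to-log.warning swap above follows the module-level logger pattern; a minimal sketch of how such messages surface (standalone example, not from the diff):

import logging

log = logging.getLogger(__name__)

# Unlike print(), log records flow through the logging hierarchy, so the
# host application can route, format, or silence them centrally.
logging.basicConfig(level=logging.WARNING)
log.warning('Could not find results for benchmark: %s.', 'example-benchmark')
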
4 changes: 1 addition & 3 deletions llmfoundry/callbacks/monolithic_ckpt_callback.py
@@ -72,9 +72,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
         ) if self.upload_to_object_store else contextlib.nullcontext(
             enter_result=save_dir)
         with dir_context_mgr as temp_save_dir:
-            save_path = str(
-                Path(temp_save_dir) /  # type: ignore
-                Path(filename))
+            save_path = str(Path(temp_save_dir) / Path(filename))
             dirname = os.path.dirname(save_path)
             if dirname:
                 os.makedirs(dirname, exist_ok=True)
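The simplified save_path line works because pathlib's / operator composes paths without casts; a small sketch with made-up paths:

from pathlib import Path

# Path.__truediv__ accepts str or Path operands, so no `# type: ignore`
# is needed once the expression fits on one line.
save_path = str(Path('/tmp/checkpoints') / Path('ep0-ba100/rank0.pt'))
print(save_path)  # -> /tmp/checkpoints/ep0-ba100/rank0.pt
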
7 changes: 5 additions & 2 deletions llmfoundry/callbacks/resumption_callbacks.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import logging
 from typing import List
 
 from composer.core import Callback, State
@@ -11,6 +12,8 @@
     'LayerFreezing',
 ]
 
+log = logging.getLogger(__name__)
+
 
 class GlobalLRScaling(Callback):
     """GlobalLRScaling.
@@ -38,7 +41,7 @@ def fit_start(self, state: State, logger: Logger):
                 group['weight_decay'] = group['lr'] * self.wd_pct
                 if 'initial_lr' in group:
                     group['initial_lr'] *= self.lr_scale
-                print(
+                log.info(
                     f"Set LR and WD to {group['lr']}, {group['weight_decay']}")
 
         for scheduler in state.schedulers:
@@ -74,7 +77,7 @@ def fit_start(self, state: State, logger: Logger):
         for name, p in state.model.named_parameters():
             if p.requires_grad and name in self.layer_names:
                 p.requires_grad = False
-                print(f'Froze layer: {name}\nParam: {p}')
+                log.debug(f'Froze layer: {name}\nParam: {p}')
                 successful_freeze = True
 
         if not successful_freeze:
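For reference, a tiny standalone sketch of the freezing pattern used by LayerFreezing (the model and layer-name prefix below are made up):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# Freeze parameters of the first layer by name prefix, mirroring the
# requires_grad flip in fit_start above.
for name, p in model.named_parameters():
    if p.requires_grad and name.startswith('0.'):
        p.requires_grad = False

print([(n, p.requires_grad) for n, p in model.named_parameters()])
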
1 change: 0 additions & 1 deletion llmfoundry/data/data.py
@@ -24,7 +24,6 @@ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset,
 
     def __iter__(self) -> Iterable[Dict[str, bytes]]:
         for sample in self.hf_dataset:
-            # print(sample)
             # convert to bytes to store in MDS binary format
             yield {'text': sample['text'].encode('utf-8')}
 
13 changes: 6 additions & 7 deletions llmfoundry/data/denoising.py
@@ -677,8 +677,7 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
     """
     span_markers = np.less(np.arange(total_tokens - 1), num_spans -
                            1)[np.random.permutation(total_tokens - 1)]
-    span_start_indicator = np.concatenate([[0],
-                                           span_markers])  # type: ignore
+    span_start_indicator = np.concatenate([np.array([0]), span_markers])
     span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
     spans = np.arange(num_spans).reshape(1, -1)
     span_lengths = np.sum(span_id == spans, axis=0)
@@ -715,13 +714,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
 
         # Ensure there's an end-of-sentence token at the end
         if ensure_eos and (noised_tokens[-1] != eos_token_id):
-            noised_tokens = np.concatenate([noised_tokens,
-                                            [eos_token_id]])  # type: ignore
+            noised_tokens = np.concatenate(
+                [noised_tokens, np.array([eos_token_id])])
 
         return noised_tokens
 
     # Masking at previous token
-    prev_token_mask = np.concatenate([[0], mask[:-1]])  # type: ignore
+    prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
 
     # Decompose mask into start-of-span mask and non-start-of-span mask
     start_of_noise_span_token = np.logical_and(mask,
@@ -740,8 +739,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
 
     # Ensure there's an end-of-sentence token at the end
     if ensure_eos and (noised_tokens[-1] != eos_token_id):
-        noised_tokens = np.concatenate([noised_tokens,
-                                        [eos_token_id]])  # type: ignore
+        noised_tokens = np.concatenate(
+            [noised_tokens, np.array([eos_token_id])])
     return noised_tokens


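The recurring change in this file, wrapping bare Python lists in np.array before np.concatenate, is what removes the `# type: ignore` suppressions; a minimal sketch with a made-up mask:

import numpy as np

mask = np.array([0, 1, 1, 0])

# A plain list like [0] concatenates fine at runtime, but mixing list and
# ndarray operands trips static type checkers; np.array([0]) keeps the
# operand types uniform.
prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
print(prev_token_mask)  # -> [0 0 1 1]
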
6 changes: 3 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -111,7 +111,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
     _validate_config(cfg.dataset)
 
     # Use EOS as the pad token if none exists
-    if tokenizer.pad_token is None:  # type: ignore
+    if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
     dataset = None  # for pyright
@@ -323,7 +323,7 @@ def _build_hf_dataset_from_remote(
                         f'at {files_searched}'
                     ) from e
                 else:
-                    print(
+                    log.debug(
                         f'Could not find {name}, looking for another extension')
                     continue

@@ -343,7 +343,7 @@ def _build_hf_dataset_from_remote(
     dist.barrier()
 
     cfg.dataset.hf_name = finetune_dir
-    print(cfg.dataset)
+    log.info(cfg.dataset)
     dataset = dataset_constructor.build_from_hf(
         cfg.dataset,
         max_seq_len=cfg.dataset.max_seq_len,
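The EOS-as-pad fallback above is a common Hugging Face pattern; a minimal sketch assuming a GPT-2 tokenizer, which ships with an EOS token but no pad token:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Reuse EOS for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(tokenizer.pad_token)  # -> <|endoftext|>
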
34 changes: 13 additions & 21 deletions llmfoundry/data/finetuning/tasks.py
@@ -32,6 +32,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 """
 
 import importlib
+import logging
 import os
 import warnings
 from typing import Any, Callable, Dict, Optional, Union
@@ -41,6 +42,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
+log = logging.getLogger(__name__)
+
 __all__ = ['dataset_constructor']
 
 
@@ -205,16 +208,14 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
 
     def get_preprocessing_fn_from_str(self,
                                       preprocessor: Optional[str],
-                                      dataset_name: Optional[str] = None,
-                                      verbose: bool = False):
+                                      dataset_name: Optional[str] = None):
         """Get a preprocessing function from a string.
 
         String can be either a registered function or an import path.
 
         Args:
             preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
             dataset_name (Optional[str]): The dataset name to look up in the registry.
-            verbose (bool): Whether to print verbose messages or not.
 
         Returns:
             Callable: The preprocessing function or None if not found.
@@ -226,33 +227,24 @@ def get_preprocessing_fn_from_str(self,
             if dataset_name is None:
                 return None
             if dataset_name in self._task_preprocessing_registry:
-                if verbose:
-                    print(
-                        f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
-                    )
+                log.info(
+                    f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
+                )
                 return self._task_preprocessing_registry[dataset_name]
             else:
-                if verbose:
-                    print(
-                        'No preprocessor was supplied and no preprocessing function ' +\
+                log.info('No preprocessor was supplied and no preprocessing function ' +\
                         f'is registered for dataset name "{dataset_name}". No additional ' +\
                         'preprocessing will be applied. If the dataset is already formatted ' +\
-                        'correctly, you can ignore this message.'
-                    )
+                        'correctly, you can ignore this message.')
                 return None
         if preprocessor in self._task_preprocessing_registry:
-            if verbose:
-                print(
-                    f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
-                )
+            log.info(
+                f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
+            )
             return self._task_preprocessing_registry[preprocessor]
 
         try:
             import_path, function_name = preprocessor.split(':', maxsplit=1)
-            if verbose:
-                print(
-                    f'Importing preprocessing function via: `from {import_path} import {function_name}`'
-                )
             module = importlib.import_module(import_path)
             preprocessing_fn = getattr(module, function_name)
         except Exception as e:
@@ -289,7 +281,7 @@ def build_from_hf(
                     proto_preprocessing_fn)
             else:
                 preprocessing_fn = self.get_preprocessing_fn_from_str(
-                    proto_preprocessing_fn, dataset_name, verbose=True)
+                    proto_preprocessing_fn, dataset_name)
 
         dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
 
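The try block above resolves 'module:function' strings with importlib; a self-contained sketch of that lookup (the 'os.path:join' spec is just an illustrative stand-in):

import importlib


def load_by_import_path(spec: str):
    # spec has the form 'some.module:function_name'.
    import_path, function_name = spec.split(':', maxsplit=1)
    module = importlib.import_module(import_path)
    return getattr(module, function_name)


join = load_by_import_path('os.path:join')
print(join('a', 'b'))  # -> a/b
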
10 changes: 7 additions & 3 deletions llmfoundry/data/packing.py
@@ -360,9 +360,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     # build tokenizer
     if 'tokenizer' not in cfg:
         raise ValueError('config must define tokenizer')
-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+
+    resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
+    if not isinstance(resolved_tokenizer_cfg, Dict):
+        raise ValueError(
+            'tokenizer config needs to be resolved by omegaconf into a Dict.')
+    tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
+
     tokenizer_name = tokenizer_cfg['name']
     tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
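The new isinstance guard replaces a `# type: ignore`: om.to_container is typed to return any of several container shapes, so narrowing to a dict convinces the checker and fails loudly otherwise. A minimal sketch with a made-up config:

from typing import Any, Dict

from omegaconf import OmegaConf as om

cfg = om.create({'tokenizer': {'name': 'gpt2', 'kwargs': {}}})

resolved = om.to_container(cfg.tokenizer, resolve=True)
if not isinstance(resolved, dict):
    raise ValueError('tokenizer config must resolve to a dict')
tokenizer_cfg: Dict[Any, Any] = resolved
print(tokenizer_cfg['name'])  # -> gpt2
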
17 changes: 10 additions & 7 deletions llmfoundry/data/text_data.py
@@ -5,7 +5,8 @@
 
 import os
 from itertools import islice
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
+from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
+                    Union, cast)
 
 import numpy as np
 import torch
@@ -193,11 +194,12 @@ def __init__(
                 '`bos_token_id` if sequences start with a BOS token.'
             )
 
-        self.split_token_id = eos_token_id
-        self.bos_mode = False
         if eos_token_id is None:
-            self.split_token_id = bos_token_id
+            self.split_token_id = cast(int, bos_token_id)
             self.bos_mode = True
+        else:
+            self.split_token_id = eos_token_id
+            self.bos_mode = False
 
     def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
         batch = self.base_collator(examples)
@@ -206,8 +208,7 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
 
     def get_sequence_id_from_batch(
             self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
-        is_separator = torch.eq(batch['input_ids'],
-                                self.split_token_id)  # type: ignore
+        is_separator = torch.eq(batch['input_ids'], self.split_token_id)
         cumulative_sep = torch.cumsum(is_separator,
                                       dim=1).to(batch['input_ids'].dtype)
         # If separator token is bos, we're already done
@@ -340,7 +341,9 @@ def build_text_dataloader(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
     loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
-    tokenizer = loader.dataset.tokenizer  # type: ignore
+    assert isinstance(loader.dataset, StreamingTextDataset)
+    tokenizer = loader.dataset.tokenizer
+
     for batch_ix, batch in enumerate(islice(loader, 5)):
         print('\n')
         print('#' * 20, f'Batch {batch_ix}', '#' * 20)
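The collator restructuring makes split_token_id unconditionally an int in the type checker's eyes, with typing.cast covering the BOS branch; a standalone sketch of the same selection logic (not the actual class):

from typing import Optional, cast


def pick_split_token(eos_token_id: Optional[int],
                     bos_token_id: Optional[int]) -> int:
    # Prefer EOS; fall back to BOS. Earlier validation is assumed to
    # guarantee at least one of the two is set, so the cast (a runtime
    # no-op) only informs the type checker.
    if eos_token_id is None:
        return cast(int, bos_token_id)
    return eos_token_id


print(pick_split_token(None, 1))  # -> 1
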
