
Commit

merge
bmosaicml committed Oct 24, 2023
2 parents 0350950 + d72902a commit 27754e8
Showing 34 changed files with 2,092 additions and 657 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/pr-gpu.yaml
@@ -32,6 +32,10 @@ jobs:
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.1.0-flash2'
container: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
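For context, the new `gpu-2.1.0-flash2` entry runs the same marker-filtered test selection as the other GPU jobs above. A minimal sketch of how a test opts into that selection, assuming the reusable workflow forwards the `markers` field to pytest's `-m` filter (the test name and body are illustrative only):

```python
import pytest
import torch


@pytest.mark.gpu  # selected when the CI job filters on the 'gpu' marker
def test_cuda_is_available():
    # Hypothetical smoke test: the GPU jobs only make sense on CUDA hosts.
    assert torch.cuda.is_available()
```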
5 changes: 5 additions & 0 deletions README.md
@@ -93,8 +93,10 @@ If you have success/failure using LLM Foundry on other systems, please let us know
|---------------------------|------------------|--------------|-------------------------------|
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
@@ -113,8 +115,11 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117
|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 | No |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 | Yes |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 | Yes (flash attention v2) |


# Installation
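The compatibility and Docker tables above are keyed on the PyTorch version, CUDA version, and GPU type. A quick way to see which row applies to a running environment, using standard PyTorch attributes (the printed values are examples, not guarantees):

```python
import torch

# Report the versions the compatibility matrix is keyed on.
print('PyTorch version:', torch.__version__)   # e.g. '2.1.0'
print('CUDA version:', torch.version.cuda)     # e.g. '12.1'; None on CPU-only builds
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))  # e.g. 'NVIDIA A100-SXM4-80GB'
```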
11 changes: 9 additions & 2 deletions llmfoundry/__init__.py
@@ -4,6 +4,11 @@
 import torch
 
 try:
+    # Before importing any transformers models, we need to disable transformers flash attention if
+    # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+    # gated import otherwise.
+    import transformers
+
     from llmfoundry import optim, utils
     from llmfoundry.data import (ConcatTokensDataset,
                                  MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
                                   ComposerHFT5)
     from llmfoundry.models.layers.attention import (
         MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
-        flash_attn_fn, scaled_multihead_dot_product_attention,
-        triton_flash_attn_fn)
+        flash_attn_fn, is_flash_v1_installed,
+        scaled_multihead_dot_product_attention, triton_flash_attn_fn)
     from llmfoundry.models.layers.blocks import MPTBlock
     from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
                                               build_ffn)
@@ -24,6 +29,8 @@
                                      MPTForCausalLM, MPTModel,
                                      MPTPreTrainedModel)
     from llmfoundry.tokenizers import TiktokenTokenizerWrapper
+    if is_flash_v1_installed():
+        transformers.utils.is_flash_attn_available = lambda: False
 
 except ImportError as e:
     try:
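The `__init__.py` change gates transformers' flash-attention support on the installed flash-attn major version. The helper itself lives in `llmfoundry.models.layers.attention` and is not shown in this diff; a minimal sketch of the kind of check it performs, under that assumption:

```python
# Sketch only: the real is_flash_v1_installed() is defined in
# llmfoundry.models.layers.attention and may differ in detail.
from packaging import version


def is_flash_v1_installed() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


import transformers

if is_flash_v1_installed():
    # Make transformers report flash attention as unavailable so it never
    # attempts the v2-only import path under a v1 install.
    transformers.utils.is_flash_attn_available = lambda: False
```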
121 changes: 16 additions & 105 deletions llmfoundry/callbacks/generate_callback.py
@@ -1,119 +1,30 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Periodically log generations to wandb from a set of prompts."""
-from typing import Any, List, Union, cast
+"""Deprecated Generate callback.
 
-import torch
-import wandb
-from composer.core import Callback, State, get_precision_context
-from composer.loggers import Logger, WandBLogger
-from composer.utils import dist, ensure_tuple
+Please use composer.callbacks.Generate instead.
+"""
+import warnings
+from typing import Any, List, Union
 
+from composer.callbacks import Generate as ComposerGenerate
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 
-class Generate(Callback):
+class Generate(ComposerGenerate):
 
     def __init__(self, prompts: List[str], batch_log_interval: int,
                  **kwargs: Any):
-        """Periodically log generations to wandb from a set of prompts.
-
-        In the main view for a run, there will be a table that will show the _last_ logged generations.
-        To compare previous iterations of the generations, you need to
-        1. Click on the run
-        2. Click on "artifacts" in the menu on the left side of the screen
-        3. Click on one of the artifacts called "predictions"
-        4. Click on the "files" tab
-        5. Click on "predictions.table.json"
-        6. On the left hand side, there are different versions of the table produced throughout training. Select one of these.
-        7. Now, when you hover over other versions, there will be a "compare" button, which will allow you to compare the currently
-           selected version to the version you add via compare.
-
-        Args:
-            prompts (List[str]): The list of prompts you would like to produce generations for
-            batch_log_interval (int): The interval (in batches) at which this callback runs
-            kwargs: All kwargs will be passed along to the call to generate. This is for things like `do_sample`, `top_p`, etc
-        """
-        self.prompts = prompts
-        self.batch_log_interval = batch_log_interval
-        self.generate_kwargs = kwargs
-        self.wandb_logger = None
-
-    def init(self, state: State, logger: Logger):
-        if dist.get_global_rank() == 0:
-            for destination in ensure_tuple(logger.destinations):
-                if isinstance(destination, WandBLogger):
-                    self.wandb_logger = destination
-
-    def batch_checkpoint(self, state: State, logger: Logger) -> None:
-        if (state.timestamp.batch.value % self.batch_log_interval) == 0:
-            self.generate(state, logger)
-
-    def generate(self, state: State, logger: Logger) -> None:
-        model = state.model
-        original_mode = model.training
-        model.eval()
-        tokenizer = cast(Tokenizer, state.model.tokenizer)
-        device = state.device
-
-        if not hasattr(model.model, 'generate'):
-            raise ValueError(
-                f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method'
-            )
-
-        # stash the original value of padding_side because generation requires left padding
-        original_padding_side = tokenizer.padding_side
-        tokenizer.padding_side = 'left'
-        if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-        tokenized_input = tokenizer(self.prompts,
-                                    return_tensors='pt',
-                                    padding=True)
-
-        for k, v in tokenized_input.items():
-            tokenized_input[k] = device.tensor_to_device(v)
-
-        # dummy forward call needed for FSDP to work consistently
-        dummy_input = torch.tensor([[0]], dtype=torch.long)
-        dummy_input = device.tensor_to_device(dummy_input)
-        with get_precision_context(state.precision):
-            with torch.no_grad():
-                assert isinstance(model.model, torch.nn.Module)
-                _ = model.model(input_ids=dummy_input)
-
-            output_token_ids = model.model.generate(  # type: ignore
-                input_ids=tokenized_input['input_ids'],
-                attention_mask=tokenized_input['attention_mask'],
-                synced_gpus=True,
-                **self.generate_kwargs,
-            )
-
-        if dist.get_global_rank() == 0:
-            if self.wandb_logger is not None:
-                assert wandb.run is not None, 'wandb should have started run'
-
-                artifact = wandb.Artifact('generate_samples_' +
-                                          str(wandb.run.id),
-                                          type='predictions')
-
-                rows = []
-                for i in range(len(self.prompts)):
-                    prompt = self.prompts[i]
-                    output_tokens = output_token_ids[i][
-                        tokenized_input['input_ids'].shape[1]:]
-                    output_text = tokenizer.decode(output_tokens,
-                                                   skip_special_tokens=True)
-
-                    rows.append([prompt, output_text])
-
-                text_table = wandb.Table(data=rows,
-                                         columns=['prompt', 'generation'])
-                artifact.add(text_table, 'predictions')
-                wandb.log_artifact(artifact)
-                wandb.log({'generations': text_table},
-                          step=state.timestamp.batch.value)
+        warnings.warn(
+            ('Accessing llmfoundry.callbacks.generate_callback.Generate '
+             'is deprecated and will be removed in a future release. '
+             'Please use composer.callbacks.Generate instead.'),
+            DeprecationWarning,
+        )
 
-        tokenizer.padding_side = original_padding_side
-        model.train(mode=original_mode)
+        interval = f'{batch_log_interval}ba'
+        super().__init__(prompts=prompts, interval=interval, **kwargs)
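With this deprecation, the wrapper simply translates `batch_log_interval` into Composer's interval string and forwards everything else. A hedged usage sketch showing the two equivalent constructions (the prompt strings and generation kwargs are illustrative only):

```python
from composer.callbacks import Generate as ComposerGenerate

from llmfoundry.callbacks.generate_callback import Generate  # deprecated shim

prompts = ['The quick brown fox', 'MosaicML is']

# Old llm-foundry style: emits a DeprecationWarning, then forwards to Composer.
deprecated_cb = Generate(prompts=prompts, batch_log_interval=100, do_sample=True)

# Equivalent direct construction: batch_log_interval=100 becomes interval='100ba'.
composer_cb = ComposerGenerate(prompts=prompts, interval='100ba', do_sample=True)
```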