diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index e16f2c8b40..1151837111 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -32,6 +32,10 @@ jobs:
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
+ - name: 'gpu-2.1.0-flash2'
+ container: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
+ markers: 'gpu'
+ pytest_command: 'coverage run -m pytest'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
diff --git a/README.md b/README.md
index 00360a320c..04bad9c519 100644
--- a/README.md
+++ b/README.md
@@ -93,8 +93,10 @@ If you have success/failure using LLM Foundry on other systems, please let us kn
|---------------------------|------------------|--------------|-------------------------------|
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
+| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
+| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
@@ -113,8 +115,11 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117
|-------------------------------------------------------------|----------------|--------------|-------------------------------------|
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 | No |
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 | No |
+| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 | No |
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 | Yes |
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 | Yes |
+| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 | Yes (flash attention v1) |
+| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 | Yes (flash attention v2) |
# Installation
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 3bb9eed043..51fa67993a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,11 @@
import torch
try:
+ # Before importing any transformers models, we need to disable transformers flash attention if
+ # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+ # gated import otherwise.
+ import transformers
+
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention,
- triton_flash_attn_fn)
+ flash_attn_fn, is_flash_v1_installed,
+ scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
@@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
+ if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
except ImportError as e:
try:
diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py
index bb5b557d37..58ba7e685e 100644
--- a/llmfoundry/callbacks/generate_callback.py
+++ b/llmfoundry/callbacks/generate_callback.py
@@ -1,119 +1,30 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-"""Periodically log generations to wandb from a set of prompts."""
-from typing import Any, List, Union, cast
+"""Deprecated Generate callback.
-import torch
-import wandb
-from composer.core import Callback, State, get_precision_context
-from composer.loggers import Logger, WandBLogger
-from composer.utils import dist, ensure_tuple
+Please use composer.callbacks.Generate instead.
+"""
+import warnings
+from typing import Any, List, Union
+
+from composer.callbacks import Generate as ComposerGenerate
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-class Generate(Callback):
+class Generate(ComposerGenerate):
def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):
- """Periodically log generations to wandb from a set of prompts.
-
- In the main view for a run, there will be a table that will show the _last_ logged generations.
- To compare previous iterations of the generations, you need to
- 1. Click on the run
- 2. Click on "artifacts" in the menu on the left side of the screen
- 3. Click on one of the artifacts called "predictions"
- 4. Click on the "files" tab
- 5. Click on "predictions.table.json"
- 6. On the left hand side, there are different versions of the table produced throughout training. Select one of these.
- 7. Now, when you hover over other versions, there will be a "compare" button, which will allow you to compare the currently
- selected version to the version you add via compare.
-
- Args:
- prompts (List[str]): The list of prompts you would like to produce generations for
- batch_log_interval (int): The interval (in batches) at which this callback runs
- kwargs: All kwargs well be passed along to the call to generate. This is for things like `do_sample`, `top_p`, etc
- """
- self.prompts = prompts
- self.batch_log_interval = batch_log_interval
- self.generate_kwargs = kwargs
- self.wandb_logger = None
-
- def init(self, state: State, logger: Logger):
- if dist.get_global_rank() == 0:
- for destination in ensure_tuple(logger.destinations):
- if isinstance(destination, WandBLogger):
- self.wandb_logger = destination
-
- def batch_checkpoint(self, state: State, logger: Logger) -> None:
- if (state.timestamp.batch.value % self.batch_log_interval) == 0:
- self.generate(state, logger)
-
- def generate(self, state: State, logger: Logger) -> None:
- model = state.model
- original_mode = model.training
- model.eval()
- tokenizer = cast(Tokenizer, state.model.tokenizer)
- device = state.device
-
- if not hasattr(model.model, 'generate'):
- raise ValueError(
- f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method'
- )
-
- # stash the original original value of padding_side because generation requires left padding
- original_padding_side = tokenizer.padding_side
- tokenizer.padding_side = 'left'
- if tokenizer.pad_token_id is None:
- tokenizer.pad_token_id = tokenizer.eos_token_id
- tokenized_input = tokenizer(self.prompts,
- return_tensors='pt',
- padding=True)
-
- for k, v in tokenized_input.items():
- tokenized_input[k] = device.tensor_to_device(v)
-
- # dummy forward call needed for FSDP to work consistently
- dummy_input = torch.tensor([[0]], dtype=torch.long)
- dummy_input = device.tensor_to_device(dummy_input)
- with get_precision_context(state.precision):
- with torch.no_grad():
- assert isinstance(model.model, torch.nn.Module)
- _ = model.model(input_ids=dummy_input)
-
- output_token_ids = model.model.generate( # type: ignore
- input_ids=tokenized_input['input_ids'],
- attention_mask=tokenized_input['attention_mask'],
- synced_gpus=True,
- **self.generate_kwargs,
- )
-
- if dist.get_global_rank() == 0:
- if self.wandb_logger is not None:
- assert wandb.run is not None, 'wandb should have started run'
-
- artifact = wandb.Artifact('generate_samples_' +
- str(wandb.run.id),
- type='predictions')
-
- rows = []
- for i in range(len(self.prompts)):
- prompt = self.prompts[i]
- output_tokens = output_token_ids[i][
- tokenized_input['input_ids'].shape[1]:]
- output_text = tokenizer.decode(output_tokens,
- skip_special_tokens=True)
-
- rows.append([prompt, output_text])
- text_table = wandb.Table(data=rows,
- columns=['prompt', 'generation'])
- artifact.add(text_table, 'predictions')
- wandb.log_artifact(artifact)
- wandb.log({'generations': text_table},
- step=state.timestamp.batch.value)
+ warnings.warn(
+ ('Accessing llmfoundry.callbacks.generate_callback.Generate '
+ 'is deprecated and will be removed in a future release. '
+ 'Please use composer.callbacks.Generate instead.'),
+ DeprecationWarning,
+ )
- tokenizer.padding_side = original_padding_side
- model.train(mode=original_mode)
+ interval = f'{batch_log_interval}ba'
+ super().__init__(prompts=prompts, interval=interval, **kwargs)
diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 492816ea07..aa3beda513 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
-import json
+import copy
import logging
import os
import tempfile
@@ -10,14 +10,14 @@
from typing import Optional, Union
import torch
-from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
-from composer.loggers import Logger
+from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedTokenizerBase
+from composer.utils.misc import create_interval_scheduler
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
@@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
+ mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
+ be registered. Default is ``None``.
+ mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
+ Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
+ ``{'task': 'llm/v1/completions'}`` respectively.
"""
def __init__(
@@ -48,6 +53,8 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
+ mlflow_registered_model_name: Optional[str] = None,
+ mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
@@ -58,6 +65,22 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]
+
+ # mlflow config setup
+ self.mlflow_registered_model_name = mlflow_registered_model_name
+ if mlflow_logging_config is None:
+ mlflow_logging_config = {}
+ if self.mlflow_registered_model_name is not None:
+ # Both the metadata and the task are needed in order for mlflow
+ # and databricks optimized model serving to work
+ if 'metadata' not in mlflow_logging_config:
+ mlflow_logging_config['metadata'] = {
+ 'task': 'llm/v1/completions'
+ }
+ if 'task' not in mlflow_logging_config:
+ mlflow_logging_config['task'] = 'text-generation'
+ self.mlflow_logging_config = mlflow_logging_config
+
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
@@ -71,6 +94,7 @@ def __init__(
self.remote_ud = None
self.last_checkpoint_batch: Optional[Time] = None
+ self.mlflow_loggers = []
def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
@@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)
+ if self.mlflow_registered_model_name is not None:
+ self.mlflow_loggers = [
+ logger_destination
+ for logger_destination in logger.destinations
+ if isinstance(logger_destination, MLFlowLogger)
+ ]
+ if len(self.mlflow_loggers) == 0:
+ raise ValueError(
+ f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+ +
+ 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
+ )
+
+ import mlflow
+ mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
+ '5GB')
+
def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused
@@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
- assert isinstance(state.model, HuggingFaceModel)
-
save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
@@ -114,9 +153,29 @@ def _save_checkpoint(self, state: State, logger: Logger):
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result
- with fsdp_state_dict_type_context(state.model.model,
- state_dict_type='full'):
- state_dict = state.model.model.state_dict()
+ log.debug('Gathering state dict')
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ if state.is_model_ddp:
+ original_model: PreTrainedModel = state.model.module.model
+ state_dict_model = state.model.module.model
+ original_tokenizer = state.model.module.tokenizer
+ elif isinstance(state.model.model, FSDP):
+ original_model: PreTrainedModel = state.model.model.module
+ state_dict_model = state.model.model
+ original_tokenizer = state.model.tokenizer
+ else:
+ original_model: PreTrainedModel = state.model.model
+ state_dict_model = state.model.model
+ original_tokenizer = state.model.tokenizer
+
+ state_dict_context = fsdp_state_dict_type_context(
+ original_model, state_dict_type='full') if (
+ (not state.is_model_ddp) and isinstance(
+ state_dict_model, FSDP)) else contextlib.nullcontext()
+
+ with state_dict_context:
+ state_dict = state_dict_model.state_dict()
# convert the state dict to the requested precision
for k, v in state_dict.items():
@@ -124,34 +183,35 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict[k] = v.to(dtype=self.dtype)
if dist.get_global_rank() == 0:
- # We raise above if the model is not a HuggingFaceModel, so this assert is safe
- assert hasattr(state.model.model, 'save_pretrained')
- state.model.model.save_pretrained(temp_save_dir,
- state_dict=state_dict)
-
- if state.model.tokenizer is not None:
- assert isinstance(state.model.tokenizer,
+ log.debug('Saving Hugging Face checkpoint to disk')
+
+ copied_config = copy.deepcopy(original_model.config)
+ if copied_config.model_type == 'mpt':
+ copied_config.attn_config['attn_impl'] = 'torch'
+ copied_config.init_device = 'cpu'
+
+ # TODO: after torch 2.1, we can load a state dict into a meta model
+ # and skip the extra model init
+ log.debug(f'Creating new model instance')
+ new_model_instance = type(original_model)(copied_config)
+ new_model_instance.to(dtype=self.dtype)
+ new_model_instance.load_state_dict(state_dict)
+ del state_dict
+
+ log.debug('Saving Hugging Face checkpoint to disk')
+ new_model_instance.save_pretrained(temp_save_dir)
+ if original_tokenizer is not None:
+ assert isinstance(original_tokenizer,
PreTrainedTokenizerBase)
- state.model.tokenizer.save_pretrained(temp_save_dir)
+ original_tokenizer.save_pretrained(temp_save_dir)
# Only need to edit files for MPT because it has custom code
- if state.model.model.config.model_type == 'mpt':
+ if original_model.config.model_type == 'mpt':
+ log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)
- with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
- edited_config = json.load(f)
-
- if state.model.model.config.model_type == 'mpt':
- edited_config['attn_config']['attn_impl'] = 'torch'
- edited_config['init_device'] = 'cpu'
-
- edited_config['torch_dtype'] = self.precision
- with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
- json.dump(edited_config, f, indent=4)
-
if self.upload_to_object_store:
assert self.remote_ud is not None
- # TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
@@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)
- dist.barrier()
+ elapsed_duration = state.get_elapsed_duration()
+ if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+ components = {'model': new_model_instance}
+ if original_tokenizer is not None:
+ components['tokenizer'] = original_tokenizer
+
+ log.debug('Logging Hugging Face model to MLFlow')
+ for i, mlflow_logger in enumerate(self.mlflow_loggers):
+ log.debug(
+ f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
+ )
+ local_save_path = str(
+ Path(temp_save_dir) / f'mlflow_save_{i}')
+
+ # TODO: Remove after mlflow fixes the bug that makes this necessary
+ import mlflow
+ mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
+ mlflow_logger.save_model(
+ flavor='transformers',
+ transformers_model=components,
+ path=local_save_path,
+ **self.mlflow_logging_config,
+ )
+ mlflow_logger.register_model(
+ model_uri=local_save_path,
+ name=self.mlflow_registered_model_name,
+ await_registration_for=3600,
+ )
diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index d685d0077d..bc41945076 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -10,13 +10,15 @@
import numpy as np
import torch
+from composer.core.data_spec import DataSpec
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
from llmfoundry.data.packing import BinPackWrapper
-from llmfoundry.data.text_data import StreamingTextDataset
+from llmfoundry.data.text_data import (StreamingTextDataset,
+ get_tokens_per_batch_func)
from llmfoundry.models import utils
__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
@@ -353,7 +355,7 @@ def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
-) -> DataLoader[Dict]:
+) -> DataSpec:
"""Constructor function for a Mixture of Denoisers dataloader.
This function constructs a dataloader that can be used to train an
@@ -506,7 +508,7 @@ def build_text_denoising_dataloader(
'but cfg.dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')
- return DataLoader(
+ dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -518,6 +520,12 @@ def build_text_denoising_dataloader(
timeout=cfg.get('timeout', 0),
)
+ token_counting_func = get_tokens_per_batch_func(
+ pad_token_id=tokenizer.pad_token_id,
+ decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
+
+ return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
def noise_token_sequence(
example: Union[torch.Tensor, Mapping[str, Any]],
@@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only(
tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_kwargs=tokenizer_kwargs)
- loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
+ loader = build_text_denoising_dataloader(cfg, tokenizer,
+ device_batch_size).dataloader
+ assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)
print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index ebb7991dde..2dde563ac6 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -6,6 +6,7 @@
import datasets as hf_datasets
import torch
+from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from torch.utils.data import DataLoader
@@ -14,6 +15,7 @@
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.text_data import get_tokens_per_batch_func
log = logging.getLogger(__name__)
@@ -23,7 +25,7 @@
def build_finetuning_dataloader(cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
- device_batch_size: int) -> DataLoader:
+ device_batch_size: int) -> DataSpec:
"""Builds a finetuning dataloader for training or evaluating.
The underlying dataset can be built through one of two code paths:
@@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
- return DataLoader(
+ dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)
assert dataset is not None
- return DataLoader(
+ dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
@@ -208,6 +210,11 @@ def build_finetuning_dataloader(cfg: DictConfig,
timeout=cfg.get('timeout', 0),
)
+ token_counting_func = get_tokens_per_batch_func(
+ pad_token_id=tokenizer.pad_token_id)
+
+ return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
@@ -442,7 +449,8 @@ def _build_collate_fn(
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
device_batch_size = 2
- dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+ dataloader = build_finetuning_dataloader(cfg, tokenizer,
+ device_batch_size).dataloader
packing = cfg.dataset.get('packing_ratio') is not None
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index d0a73be801..1532de276e 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.dataset.max_leftovers_to_keep = None
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
- max(raw_batch_sizes) * 100)
+ max(raw_batch_sizes) * 100).dataloader
# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index afdd243adf..93af2f63ed 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -11,6 +11,8 @@
import numpy as np
import torch
import transformers
+from composer.core.data_spec import DataSpec
+from composer.core.types import Batch
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import Stream, StreamingDataset
@@ -237,7 +239,7 @@ def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
-) -> DataLoader:
+) -> DataSpec:
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
if cfg.dataset.get('group_method', None) is not None:
raise NotImplementedError(
@@ -281,7 +283,7 @@ def build_text_dataloader(
eos_token_id=eos_token_id,
bos_token_id=bos_token_id)
- return DataLoader(
+ dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=device_batch_size,
@@ -293,6 +295,58 @@ def build_text_dataloader(
timeout=cfg.get('timeout', 0),
)
+ # If we pretokenized, we may not have padding, in which case the
+ # tokenizer may not have a pad_token_id. In this case, we can
+ # just use the default token counting function. This is correct
+ # because we do not support training on pretokenized data with padding,
+ # and if tokenizing on the fly, we require that the tokenizer has a pad token.
+ token_counting_func = None
+ if tokenizer.pad_token_id is not None:
+ token_counting_func = get_tokens_per_batch_func(
+ pad_token_id=tokenizer.pad_token_id)
+
+ return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
+
+
+def get_tokens_per_batch_func(pad_token_id: int,
+ decoder_only: bool = True
+ ) -> Callable[[Batch], int]:
+ """Returns a callable that counts the number of tokens in a batch.
+
+ Args:
+ pad_token_id (int): The id of the padding token.
+ decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only)
+ or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``.
+
+ Returns:
+ Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
+ """
+
+ def get_num_samples_in_batch(batch: Batch) -> int:
+ if not isinstance(batch, Mapping) or 'input_ids' not in batch:
+ raise ValueError(
+ 'get_tokens_per_batch_func() requires a batch with an input_ids key'
+ )
+
+ if not decoder_only and 'decoder_input_ids' not in batch:
+ raise ValueError(
+ 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
+ )
+
+ # Count number of non padding tokens in batch
+ input_ids_tokens = int(
+ torch.sum(batch['input_ids'] != pad_token_id).item())
+
+ # For encoder decoder models only
+ decoder_input_ids_tokens = 0
+ if not decoder_only:
+ decoder_input_ids_tokens = int(
+ torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
+
+ return input_ids_tokens + decoder_input_ids_tokens
+
+ return get_num_samples_in_batch
+
# Helpful to test if your dataloader is working locally
# Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
@@ -353,7 +407,8 @@ def build_text_dataloader(
tokenizer_kwargs = {'model_max_length': args.max_seq_len}
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
- loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
+ loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
+ assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)
tokenizer = loader.dataset.tokenizer
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 923f658c0a..ac1d63709c 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -25,8 +25,7 @@
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
- get_llama_attention_patch_fn
+from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
try:
@@ -97,12 +96,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
+ use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
+ False)
+ if use_flash_attention_2 and not is_flash_v2_installed():
+ raise ValueError(
+ 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ + 'Please install flash_attn==2.3.2`.')
+
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
+ # This is not how you are supposed to set this, but transformers currently only
+ # supports enabling flash attention 2 when using the from_pretrained API.
+ # We need to support it for both from_pretrained and from_config, so we have to
+ # set the private attribute here. This will just skip all of transformers'
+ # validation logic that it is ok to use flash attention 2, so we check
+ # whether it is installed above, and whether the chosen config supports it here.
+ # https://github.com/huggingface/transformers/issues/26878
+ config._flash_attn_2_enabled = use_flash_attention_2
+
# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
@@ -202,6 +217,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention
+
+ from llmfoundry.models.layers.llama_attention_monkeypatch import \
+ get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 17b66dc569..9b91e70b2f 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -423,6 +423,7 @@ def forward(
)
# initialize the past key values cache if it should be used
+ presents = () if use_cache else None
if use_cache and past_key_values is None:
past_key_values = [() for _ in range(self.config.n_layers)
] # type: ignore
@@ -435,7 +436,7 @@ def forward(
all_hidden_states = all_hidden_states + (x,)
past_key_value = (past_key_values[b_idx]
if past_key_values is not None else None)
- x, attn_weights, past_key_value = block(
+ x, attn_weights, present = block(
x,
past_key_value=past_key_value,
attn_bias=attn_bias,
@@ -443,8 +444,8 @@ def forward(
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
)
- if past_key_values is not None:
- past_key_values[b_idx] = past_key_value
+ if presents is not None:
+ presents += (present,)
if output_attentions:
assert all_self_attns is not None # pyright
@@ -459,7 +460,7 @@ def forward(
return BaseModelOutputWithPast(
last_hidden_state=x,
- past_key_values=past_key_values,
+ past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py
new file mode 100644
index 0000000000..4a6d21c873
--- /dev/null
+++ b/llmfoundry/optim/scheduler.py
@@ -0,0 +1,159 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Experimental learning rate schedulers used for training LLMs."""
+
+import textwrap
+import warnings
+from typing import Union
+
+from composer.core import State, Time, TimeUnit
+from composer.optim import ComposerScheduler, LinearScheduler
+from composer.optim.scheduler import _convert_time
+
+__all__ = ['InverseSquareRootWithWarmupScheduler']
+
+
+def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
+ name: str) -> None:
+ if isinstance(time, str):
+ time = Time.from_timestring(time)
+ if isinstance(t_max, str):
+ t_max = Time.from_timestring(t_max)
+
+ assert not isinstance(time, str) and not isinstance(t_max, str)
+
+ if time.unit != t_max.unit:
+ raise ValueError(f'{time.unit=} does not match {t_max.unit=}.')
+
+
+def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
+ if isinstance(time, str):
+ time = Time.from_timestring(time)
+
+ assert not isinstance(time, str)
+
+ if time.unit == TimeUnit('dur'):
+ raise ValueError(f'{name} cannot be in units of "dur".')
+
+
+class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
+ r"""Inverse square root LR decay with warmup and optional linear cooldown.
+
+ Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:
+
+ .. math::
+ \alpha(t) = \begin{cases}
+ t / t_{warmup}, & \text{if } t < t_{warmup} \\
+ \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} <= t < t_{max} - t_{cooldown} \\
+ \alpha_i + (alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
+ \end{cases}
+
+ Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_scale`), as:
+
+ .. math::
+ \tau_d = (t - t_{warmup} + t_{scale}) / {t_scale}
+
+ :math:`\alpha_i` as the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
+ and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), as:
+
+ .. math::
+ \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}
+
+ Where :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
+ :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
+ :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decays to at infinite time,
+ and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.
+
+ Note, :math:`\alpha_{f,decay} >= \alpha_{f,cooldown}` to ensure that the learning rate is monotonically decreasing after warmup.
+
+ Also note, ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; since this schedule is designed for continual learning,
+ ``max_duration`` is expected to change. Instead, these parameters need to be specified in the same units as ``max_duration`` passed to the trainer.
+
+ Args:
+ t_warmup (str | Time): The warmup time.
+ t_scale (str | Time): The time scale.
+ t_cooldown (str | Time): The cooldown time.
+ t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
+ alpha_f_decay (float): The learning rate multiplier to decay inverse square root decay to. Default = ``0.0``.
+ alpha_f_cooldown (float): The learning rate multiplier to decay linear cooldown to. Default = ``0.0``.
+ """
+
+ def __init__(self,
+ t_warmup: Union[str, Time],
+ t_scale: Union[str, Time],
+ t_cooldown: Union[str, Time],
+ t_max: Union[str, Time] = '1dur',
+ alpha_f_decay: float = 0.0,
+ alpha_f_cooldown: float = 0.0) -> None:
+ if alpha_f_decay < alpha_f_cooldown:
+ raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
+ f'Current: alpha_f_decay={alpha_f_decay}, '
+ f'alpha_f_cooldown={alpha_f_cooldown}.'))
+ _raise_if_units_dur(t_warmup, 't_warmup')
+ _raise_if_units_dur(t_scale, 't_scale')
+ _raise_if_units_dur(t_cooldown, 't_cooldown')
+ self.t_warmup = t_warmup
+ self.t_scale = t_scale
+ self.t_cooldown = t_cooldown
+ self.t_max = t_max
+ self.alpha_f_decay = alpha_f_decay
+ self.alpha_f_cooldown = alpha_f_cooldown
+ self.warmup_scheduler = LinearScheduler(alpha_i=0.0,
+ alpha_f=1.0,
+ t_max=t_warmup)
+
+ def __call__(self, state: State, ssr: float = 1.0) -> float:
+ assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
+ _raise_if_units_dont_match(self.t_warmup, state.max_duration,
+ 't_warmup')
+ _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale')
+ _raise_if_units_dont_match(self.t_cooldown, state.max_duration,
+ 't_cooldown')
+
+ t_warmup = _convert_time(self.t_warmup, state)
+ if t_warmup.value == 0:
+ warnings.warn(
+ textwrap.dedent("""\
+ The warmup duration is 0. If warmup was specified as a fraction of the total
+ training duration, the warmup duration is calculated in the
+ same unit as the trainer's max_duration parameter."""))
+
+ if state.timestamp < t_warmup:
+ return self.warmup_scheduler(state)
+
+ t_scale = _convert_time(self.t_scale, state, ssr=ssr)
+ t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr)
+ t_max = _convert_time(self.t_max, state, ssr=ssr)
+ current_time = state.timestamp.get(t_scale.unit)
+
+ t_shift = t_scale - t_warmup
+ # t_cooldown_start is max of t_warmup, t_max - t_cooldown
+ t_cooldown_start = t_max - t_cooldown
+ if t_cooldown_start < t_warmup:
+ t_cooldown_start = t_warmup
+
+ if state.timestamp < t_cooldown_start:
+ # Rescale LR by a coefficient equal to the inverse square root of the time
+ # elapsed after warmup, rescaled by the time scale, such that, at
+ # infinite time, the LR decays to alpha_f_decay.
+ coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5
+ current_factor = (self.alpha_f_decay + coeff *
+ (1.0 - self.alpha_f_decay))
+ return current_factor
+
+ else:
+ coeff = 1 / ((t_cooldown_start + t_shift) / t_scale).value**0.5
+ alpha_i = self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay)
+
+ if t_cooldown.value == 0:
+ return alpha_i
+
+ # Linearly decay the LR from its value at the step at which cooldown
+ # started to alpha_f_cooldown over t_cooldown time.
+ frac_of_cooldown = ((current_time - t_cooldown_start) /
+ t_cooldown).value
+ frac_of_cooldown = min(1.0, frac_of_cooldown)
+ current_factor = (alpha_i + frac_of_cooldown *
+ (self.alpha_f_cooldown - alpha_i))
+ return current_factor
diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py
index 001be6a030..45192e09dd 100644
--- a/llmfoundry/tokenizers/tiktoken.py
+++ b/llmfoundry/tokenizers/tiktoken.py
@@ -21,6 +21,7 @@ def __init__(self,
model_name: Optional[str] = None,
encoding_name: Optional[str] = None,
add_bos_token: bool = False,
+ add_eos_token: bool = False,
unk_token: Optional[str] = '<|endoftext|>',
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
@@ -36,6 +37,7 @@ def __init__(self,
encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
Either model_name or encoding_name must be set, but not both.
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
+ add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
@@ -66,10 +68,12 @@ def __init__(self,
'You need to specify either model_name or encoding_name.')
self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
unk_token=unk_token,
eos_token=eos_token,
bos_token=bos_token,
@@ -151,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
- return self.added_tokens_decoder[ids]
+ return str(self.added_tokens_decoder[ids])
return self._convert_id_to_token(ids)
@@ -167,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
- tokens.append(self.added_tokens_decoder[index])
+ tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)
@@ -179,17 +183,15 @@ def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
- if self.add_bos_token:
- bos_token_ids = [self.bos_token_id]
- else:
- bos_token_ids = []
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
- output = bos_token_ids + token_ids_0
+ output = bos_token_id + token_ids_0 + eos_token_id
- if token_ids_1 is None:
- return output
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
- return output + bos_token_ids + token_ids_1
+ return output
def get_special_tokens_mask(
self,
@@ -221,15 +223,13 @@ def get_special_tokens_mask(
token_ids_1=token_ids_1,
already_has_special_tokens=True)
- if not self.add_bos_token:
- return super().get_special_tokens_mask(
- token_ids_0=token_ids_0,
- token_ids_1=token_ids_1,
- already_has_special_tokens=False)
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
- return [1] + ([0] * len(token_ids_0))
- return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
+ bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
def create_token_type_ids_from_sequences(
self,
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index b58df23302..2f46ed0653 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -3,13 +3,14 @@
import logging
import os
+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from composer import algorithms
-from composer.callbacks import (EarlyStopper, LRMonitor, MemoryMonitor,
- OptimizerMonitor, RuntimeEstimator,
- SpeedMonitor)
+from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
+ MemoryMonitor, OptimizerMonitor,
+ RuntimeEstimator, SpeedMonitor)
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import \
get_icl_task_dataloader
@@ -26,12 +27,13 @@
from torch.optim.optimizer import Optimizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, Generate,
- GlobalLRScaling, HuggingFaceCheckpointer,
- LayerFreezing, MonolithicCheckpointSaver,
+from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
+ HuggingFaceCheckpointer, LayerFreezing,
+ MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
+from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
log = logging.getLogger(__name__)
@@ -89,7 +91,21 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
'log_optimizer_metrics', True),)
elif name == 'generate_callback':
prompts = kwargs.pop('prompts')
- return Generate(prompts=list(prompts), **kwargs)
+ interval = kwargs.pop('interval', None)
+ # Generate callback used to be batch_log_interval, so this is for backwards compatibility
+ if interval is None:
+ batch_log_interval: str = kwargs.pop('batch_log_interval', '')
+ if batch_log_interval:
+ interval = f'{batch_log_interval}ba'
+ warnings.warn(
+ ('generate_callback.batch_log_interval is deprecated and will be removed in a future release.'
+ f'Please use interval: {interval}'),
+ DeprecationWarning,
+ )
+ else:
+ raise KeyError(
+ '"interval" must be specified with generate callback')
+ return Generate(prompts=list(prompts), interval=interval, **kwargs)
elif name == 'global_lr_scaling':
return GlobalLRScaling(**kwargs)
elif name == 'layer_freezing':
@@ -158,6 +174,8 @@ def build_scheduler(name: str,
return ConstantWithWarmupScheduler(**scheduler_config)
elif name == 'cosine_with_warmup':
return CosineAnnealingWithWarmupScheduler(**scheduler_config)
+ elif name == 'inv_sqrt_with_warmup':
+ return InverseSquareRootWithWarmupScheduler(**scheduler_config)
elif name == 'linear_decay_with_warmup':
return LinearWithWarmupScheduler(**scheduler_config)
else:
diff --git a/scripts/eval/README.md b/scripts/eval/README.md
index 201e61959c..ca97cc4bfb 100644
--- a/scripts/eval/README.md
+++ b/scripts/eval/README.md
@@ -31,7 +31,7 @@ You can also modify the specific benchmarks executed and their formatting by mod
### Evaluation during training
-To run evaluatio during training, download this repo, follow the instructions in `scripts/train/README.md` to perform single node pre-training and run the following commands
+To run evaluation during training, download this repo, follow the instructions in `scripts/train/README.md` to perform single node pre-training and run the following commands
```bash
@@ -45,7 +45,7 @@ You can also modify the specific benchmarks executed and their formatting by mod
ICL evaluation can be done offline via the `scripts/eval/eval.py` or during training via `scripts/train/train.py`.
-In order to do ICL evaluation you must specify a set of benchmarks you'd like to run via the `icl_tasks` key in your eval/training config. `icl_tasks` can either consist of config, or it can be a file path pointing to a locally accessible YAML config (see `scripts/eval/yamls/icl_tasks.yaml` for an example).
+In order to do ICL evaluation you must specify a set of benchmarks you'd like to run via the `icl_tasks` key in your eval/training config. `icl_tasks` can either consist of config, or it can be a file path pointing to a locally accessible YAML config (see `scripts/eval/yamls/tasks.yaml` for an example).
#### ICL task YAML format
diff --git a/scripts/eval/local_data/MODEL_GAUNTLET.md b/scripts/eval/local_data/MODEL_GAUNTLET.md
index ed30d866dd..4a0c8b93fe 100644
--- a/scripts/eval/local_data/MODEL_GAUNTLET.md
+++ b/scripts/eval/local_data/MODEL_GAUNTLET.md
@@ -263,12 +263,12 @@ Programming tasks evaluate the model's ability to understand code, write functio
- Number of few shot examples: 0
- Random baseline accuracy: 0%
36. HumanEval C++ code generation
- - Description: HumanEval C++ consists of 161 C++ programming challenges, in which the model is presented with the method signature and docstring comment for a C++ program and is expected to complete the program. We then test the resultant code’s functional correctness on a number of test input/output pairs.
+ - Description: HumanEval C++ consists of 161 C++ programming challenges, in which the model is presented with the method signature and docstring comment for a C++ program and is expected to complete the program. We then test the resultant code’s functional correctness on a number of test input/output pairs. The C++ translation of HumanEval comes from the [CodeGeex](https://huggingface.co/datasets/THUDM/humaneval-x/viewer/cpp) project.
- Year released: 2022
- Number of few shot examples: 0
- Random baseline accuracy: 0%
37. HumanEval JS code generation
- - Description: HumanEval JS consists of 164 Javscript programming challenges, in which the model is presented with the method signature and docstring comment for a Javacript program and is expected to complete the program. We then test the resultant code’s functional correctness on a number of test input/output pairs.
+ - Description: HumanEval JS consists of 164 Javscript programming challenges, in which the model is presented with the method signature and docstring comment for a Javacript program and is expected to complete the program. We then test the resultant code’s functional correctness on a number of test input/output pairs. The JS translation of HumanEval comes from the [CodeGeex](https://huggingface.co/datasets/THUDM/humaneval-x/viewer/cpp) project.
- Year released: 2022
- Number of few shot examples: 0
- Random baseline accuracy: 0%
diff --git a/scripts/misc/update_hub_code.py b/scripts/misc/update_hub_code.py
index 9fbb76977f..ee5f6935a3 100644
--- a/scripts/misc/update_hub_code.py
+++ b/scripts/misc/update_hub_code.py
@@ -14,8 +14,24 @@
from llmfoundry.utils.huggingface_hub_utils import \
edit_files_for_hf_compatibility
+_ALL_MODELS = [
+ 'mosaicml/mpt-7b',
+ 'mosaicml/mpt-7b-instruct',
+ 'mosaicml/mpt-7b-chat',
+ 'mosaicml/mpt-30b',
+ 'mosaicml/mpt-30b-chat',
+ 'mosaicml/mpt-30b-instruct',
+ 'mosaicml/mpt-7b-8k',
+ 'mosaicml/mpt-7b-8k-instruct',
+ 'mosaicml/mpt-7b-8k-chat',
+ 'mosaicml/mpt-7b-storywriter',
+]
+
def main(hf_repos_for_upload: List[str]):
+ if len(hf_repos_for_upload) == 1 and hf_repos_for_upload[0] == 'all':
+ hf_repos_for_upload = _ALL_MODELS
+
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%B %d, %Y %H:%M:%S')
@@ -61,7 +77,7 @@ def main(hf_repos_for_upload: List[str]):
create_pr=True,
)
- print(f'PR opened: {result}')
+ print(f'PR opened: {result}\n')
if __name__ == '__main__':
diff --git a/scripts/train/README.md b/scripts/train/README.md
index f10fdf59f0..4c706dc040 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
- 2. [Dataset Preparation](#datasetpreparation)
- 3. [How to start single and multi-node pretraining](#howtostartpretraining)
-2. [Part 2: LLM Finetuning](#llmfinetuning)
+ 1. [Dataset Preparation](#datasetpreparation)
+ 1. [How to start single and multi-node pretraining](#howtostartpretraining)
+1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
- 2. [Using a local dataset](#localdataset)
- 3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
-3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
-4. [FAQ: Optimizing Performance](#optimizingperformance)
+ 1. [Using a local dataset](#localdataset)
+ 1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
+1. [Using Flash Attention](#flashattention)
+1. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
+1. [FAQ: Optimizing Performance](#optimizingperformance)
# Part 1: LLM Pretraining
@@ -332,6 +333,53 @@ train_loader:
...
```
+# Using Flash Attention
+
+Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
+
+For MPT, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: mpt_causal_lm
+ ...
+ attn_config:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_impl: flash
+ ...
+```
+
+If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: mosaicml/mpt-7b
+ ...
+ config_overrides:
+ # Will use either V1 or V2 depending on what is installed
+ # "triton" will use the Triton implementation
+ attn_config:
+ attn_impl: flash
+ ...
+```
+
+For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ use_flash_attention_2: True # Will be automatically set to True if Flash Attention V2 is installed and the model supports it
+ ...
+```
+HuggingFace models currently only support Flash Attention V2.
+
+For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
+```yaml
+model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
+ attention_patch_type: triton
+ ...
+```
# FAQ: How many GPUs do I need to train a LLM?
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
diff --git a/scripts/train/benchmarking/README.md b/scripts/train/benchmarking/README.md
index 7164e93bd8..c3c8bc1c74 100644
--- a/scripts/train/benchmarking/README.md
+++ b/scripts/train/benchmarking/README.md
@@ -69,176 +69,218 @@ Our microbatching engine enables microbatch sizes that do not divde Global Batch
[comment]: # (TODO: Update tables with torch 2.0 after next Composer release)
+## H100 80GB BF16
+| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| 70b | 2048 | 64 | h100_80gb | 42.57 | 56.76 | 421 | 8 | 4 | 2048 | 32 | 66523 | 1039 | 4194304 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 |
+| 70b | 2048 | 32 | h100_80gb | 36.15 | 48.2 | 357 | 2 | 16 | 1024 | 13 | 28242 | 882 | 2097152 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 64862437376 |
+| 30b | 8192 | 8 | h100_80gb | 29.92 | 39.9 | 296 | 1 | 21 | 168 | 1 | 11072 | 1384 | 1376256 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 30019254272 |
+| 30b | 4096 | 8 | h100_80gb | 35.86 | 47.81 | 354 | 1 | 21 | 168 | 3 | 14419 | 1802 | 688128 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29989894144 |
+| 30b | 2048 | 32 | h100_80gb | 43.92 | 58.57 | 434 | 14 | 3 | 1344 | 36 | 73860 | 2308 | 2752512 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 |
+| 30b | 2048 | 16 | h100_80gb | 43.07 | 57.42 | 426 | 10 | 3 | 480 | 17 | 36209 | 2263 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 |
+| 30b | 2048 | 8 | h100_80gb | 38.11 | 50.82 | 377 | 3 | 21 | 504 | 7 | 16022 | 2002 | 1032192 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 |
+| 30b | 1024 | 8 | h100_80gb | 38.76 | 51.68 | 383 | 6 | 21 | 1008 | 16 | 16672 | 2084 | 1032192 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29967874048 |
+| 13b | 32768 | 8 | h100_80gb | 31.68 | 42.24 | 313 | 1 | 3 | 24 | 0 | 15812 | 1976 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 13011240960 |
+| 13b | 16384 | 8 | h100_80gb | 35.55 | 47.4 | 351 | 3 | 3 | 72 | 1 | 23881 | 2985 | 1179648 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12927354880 |
+| 13b | 4096 | 8 | h100_80gb | 41.6 | 55.47 | 411 | 10 | 3 | 240 | 9 | 37740 | 4717 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12864440320 |
+| 13b | 2048 | 64 | h100_80gb | 39.86 | 39.86 | 394 | 2 | 1 | 128 | 150 | 307209 | 4800 | 262144 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 12853954560 |
+| 13b | 2048 | 32 | h100_80gb | 39.95 | 39.95 | 395 | 2 | 1 | 64 | 75 | 153960 | 4811 | 131072 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 12853954560 |
+| 13b | 2048 | 16 | h100_80gb | 39.58 | 39.58 | 391 | 2 | 1 | 32 | 37 | 76280 | 4767 | 65536 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 12853954560 |
+| 13b | 2048 | 8 | h100_80gb | 39.79 | 39.79 | 393 | 2 | 1 | 16 | 18 | 38336 | 4792 | 32768 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 12853954560 |
+| 13b | 1024 | 8 | h100_80gb | 44.27 | 59.03 | 438 | 40 | 3 | 960 | 42 | 44019 | 5502 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12848711680 |
+| 7b | 65536 | 8 | h100_80gb | 28.59 | 38.13 | 282 | 1 | 2 | 16 | 0 | 15654 | 1956 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6918905856 |
+| 7b | 32768 | 8 | h100_80gb | 30.94 | 41.25 | 306 | 2 | 2 | 32 | 0 | 26550 | 3318 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6784688128 |
+| 7b | 8192 | 8 | h100_80gb | 37.14 | 49.52 | 367 | 8 | 2 | 128 | 6 | 55481 | 6935 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6684024832 |
+| 7b | 4096 | 8 | h100_80gb | 40.42 | 53.9 | 399 | 16 | 2 | 256 | 16 | 68893 | 8611 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6667247616 |
+| 7b | 2048 | 8 | h100_80gb | 46.44 | 46.44 | 459 | 6 | 1 | 48 | 41 | 85144 | 10643 | 98304 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 |
+| 7b | 1024 | 8 | h100_80gb | 42.83 | 57.11 | 423 | 64 | 2 | 1024 | 79 | 81628 | 10203 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6654664704 |
+| 3b | 65536 | 8 | h100_80gb | 26.81 | 35.74 | 265 | 1 | 2 | 16 | 0 | 26099 | 3262 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 2814366720 |
+| 3b | 32768 | 8 | h100_80gb | 28.84 | 38.46 | 285 | 3 | 6 | 144 | 1 | 46984 | 5873 | 4718592 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 2730480640 |
+| 3b | 16384 | 8 | h100_80gb | 36.34 | 36.34 | 359 | 1 | 6 | 48 | 5 | 89223 | 11152 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2688537600 |
+| 3b | 8192 | 8 | h100_80gb | 40.31 | 40.31 | 398 | 3 | 6 | 144 | 16 | 132626 | 16578 | 1179648 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2667566080 |
+| 3b | 4096 | 8 | h100_80gb | 42.31 | 42.31 | 418 | 5 | 6 | 240 | 40 | 167712 | 20964 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2657080320 |
+| 3b | 2048 | 64 | h100_80gb | 40.8 | 40.8 | 403 | 6 | 3 | 1152 | 703 | 1441663 | 22525 | 2359296 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 32 | h100_80gb | 41.7 | 41.7 | 412 | 6 | 3 | 576 | 359 | 736701 | 23021 | 1179648 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 16 | h100_80gb | 43.73 | 43.73 | 432 | 10 | 3 | 480 | 188 | 386285 | 24142 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2651837440 |
+| 3b | 1024 | 8 | h100_80gb | 46.2 | 46.2 | 457 | 20 | 6 | 960 | 211 | 216369 | 27046 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2649216000 |
+| 3b | 512 | 8 | h100_80gb | 46.32 | 46.32 | 458 | 40 | 6 | 1920 | 436 | 223721 | 27965 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2647905280 |
+| 1b | 65536 | 8 | h100_80gb | 26.34 | 35.12 | 260 | 1 | 2 | 16 | 0 | 44050 | 5506 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 1445974016 |
+| 1b | 32768 | 8 | h100_80gb | 33.54 | 33.54 | 331 | 1 | 4 | 32 | 2 | 96203 | 12025 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1378865152 |
+| 1b | 16384 | 8 | h100_80gb | 35.22 | 35.22 | 348 | 2 | 4 | 64 | 9 | 157194 | 19649 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1345310720 |
+| 1b | 8192 | 8 | h100_80gb | 37.73 | 37.73 | 373 | 3 | 4 | 96 | 28 | 233256 | 29157 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1328533504 |
+| 1b | 4096 | 8 | h100_80gb | 40.26 | 40.26 | 398 | 7 | 4 | 224 | 75 | 308282 | 38535 | 917504 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1320144896 |
+| 1b | 2048 | 64 | h100_80gb | 40.85 | 40.85 | 404 | 20 | 1 | 1280 | 1387 | 2841754 | 44402 | 2621440 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 32 | h100_80gb | 41.52 | 41.52 | 410 | 20 | 1 | 640 | 705 | 1444183 | 45130 | 1310720 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 16 | h100_80gb | 42.36 | 42.36 | 419 | 20 | 1 | 320 | 359 | 736596 | 46037 | 655360 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 8 | h100_80gb | 41.82 | 41.82 | 413 | 14 | 1 | 112 | 177 | 363645 | 45455 | 229376 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1315950592 |
+| 1b | 1024 | 8 | h100_80gb | 41.95 | 41.95 | 415 | 18 | 4 | 576 | 382 | 391287 | 48910 | 589824 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1313853440 |
+| 1b | 512 | 8 | h100_80gb | 43.21 | 43.21 | 427 | 56 | 4 | 1792 | 816 | 418201 | 52275 | 917504 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1312804864 |
+| 760m | 32768 | 8 | h100_80gb | 31.84 | 31.84 | 315 | 1 | 2 | 16 | 3 | 130333 | 16291 | 524288 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 807656448 |
+| 760m | 16384 | 8 | h100_80gb | 33.57 | 33.57 | 332 | 3 | 2 | 48 | 13 | 222521 | 27815 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 782490624 |
+| 760m | 8192 | 8 | h100_80gb | 34.84 | 34.84 | 344 | 6 | 2 | 96 | 40 | 334602 | 41825 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 769907712 |
+| 760m | 4096 | 8 | h100_80gb | 35.83 | 35.83 | 354 | 12 | 2 | 192 | 108 | 443674 | 55459 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 763616256 |
+| 760m | 2048 | 32 | h100_80gb | 37.57 | 37.57 | 371 | 24 | 1 | 768 | 1062 | 2175091 | 67971 | 1572864 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 16 | h100_80gb | 37.89 | 37.89 | 374 | 24 | 1 | 384 | 535 | 1096819 | 68551 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 8 | h100_80gb | 34.9 | 34.9 | 345 | 24 | 2 | 384 | 246 | 505177 | 63147 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 760470528 |
+| 760m | 1024 | 8 | h100_80gb | 39.76 | 39.76 | 393 | 48 | 2 | 768 | 613 | 628648 | 78581 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 758897664 |
+| 760m | 512 | 8 | h100_80gb | 40.42 | 40.42 | 399 | 96 | 2 | 1536 | 1308 | 669998 | 83749 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 758111232 |
+
+## H100 80GB FP8
+| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| 3b | 32768 | 8 | h100_80gb | 14.38 | 19.18 | 284 | 3 | 6 | 144 | 1 | 46853 | 5856 | 4718592 | amp_fp8 | DEFAULT | FULL_SHARD | True | False | 2730480640 |
+| 3b | 8192 | 8 | h100_80gb | 23.28 | 23.28 | 460 | 3 | 6 | 144 | 18 | 153174 | 19146 | 1179648 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 2667566080 |
+| 3b | 2048 | 8 | h100_80gb | 27.7 | 27.7 | 548 | 10 | 6 | 480 | 119 | 244692 | 30586 | 983040 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 2651837440 |
+| 3b | 512 | 8 | h100_80gb | 30.25 | 30.25 | 598 | 40 | 6 | 1920 | 570 | 292217 | 36527 | 983040 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 2647905280 |
+| 1b | 32768 | 8 | h100_80gb | 17.55 | 17.55 | 347 | 1 | 4 | 32 | 3 | 100643 | 12580 | 1048576 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 1378865152 |
+| 1b | 8192 | 8 | h100_80gb | 20.71 | 20.71 | 409 | 2 | 4 | 64 | 31 | 256087 | 32010 | 524288 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 1328533504 |
+| 1b | 512 | 8 | h100_80gb | 29.06 | 29.06 | 575 | 56 | 4 | 1792 | 1098 | 562523 | 70315 | 917504 | amp_fp8 | DEFAULT | FULL_SHARD | False | False | 1312804864 |
+
## A100 80GB with 1600 Gbps node-node interconnect (RoCE)
-| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| 70b | 2048 | 64 | a100_80gb | 53.33 | 71.1 | 8 | 4 | 2048 | 12 | 26274 | 410 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
-| 70b | 2048 | 32 | a100_80gb | 48.56 | 64.75 | 2 | 16 | 1024 | 5 | 11962 | 373 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
-| 30b | 8192 | 8 | a100_80gb | 42.66 | 56.89 | 1 | 21 | 168 | 0 | 4977 | 622 | 1376256 | bf16 | PURE | FULL_SHARD | True | False | 30019254272 |
-| 30b | 4096 | 8 | a100_80gb | 49.12 | 65.49 | 1 | 21 | 168 | 1 | 6227 | 778 | 688128 | bf16 | PURE | FULL_SHARD | True | False | 29989894144 |
-| 30b | 2048 | 64 | a100_80gb | 52.93 | 70.57 | 16 | 3 | 3072 | 27 | 56126 | 876 | 6291456 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 30b | 2048 | 32 | a100_80gb | 53.48 | 71.3 | 14 | 3 | 1344 | 13 | 28353 | 886 | 2752512 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 30b | 2048 | 16 | a100_80gb | 53.4 | 71.2 | 10 | 3 | 480 | 6 | 14157 | 884 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 30b | 2048 | 8 | a100_80gb | 47.57 | 63.43 | 3 | 21 | 504 | 3 | 6305 | 788 | 1032192 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 30b | 1024 | 8 | a100_80gb | 51.69 | 68.92 | 6 | 21 | 1008 | 6 | 7010 | 876 | 1032192 | bf16 | PURE | FULL_SHARD | True | False | 29967874048 |
-| 30b | 512 | 8 | a100_80gb | 49.23 | 65.63 | 12 | 21 | 2016 | 13 | 6754 | 844 | 1032192 | bf16 | PURE | FULL_SHARD | True | False | 29964204032 |
-| 13b | 32768 | 8 | a100_80gb | 49.53 | 66.04 | 1 | 3 | 24 | 0 | 7795 | 974 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 13011240960 |
-| 13b | 16384 | 8 | a100_80gb | 51.71 | 68.94 | 3 | 3 | 72 | 0 | 10953 | 1369 | 1179648 | bf16 | PURE | FULL_SHARD | True | False | 12927354880 |
-| 13b | 8192 | 8 | a100_80gb | 52.83 | 70.44 | 5 | 3 | 120 | 1 | 13531 | 1691 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 12885411840 |
-| 13b | 4096 | 8 | a100_80gb | 53.62 | 71.5 | 10 | 3 | 240 | 3 | 15339 | 1917 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 12864440320 |
-| 13b | 2048 | 64 | a100_80gb | 52.51 | 70.01 | 32 | 1 | 2048 | 62 | 127624 | 1994 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 32 | a100_80gb | 52.86 | 70.48 | 32 | 1 | 1024 | 31 | 64241 | 2007 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 16 | a100_80gb | 53.14 | 70.86 | 24 | 1 | 384 | 15 | 32291 | 2018 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 8 | a100_80gb | 54.38 | 72.51 | 20 | 3 | 480 | 8 | 16522 | 2065 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 1024 | 8 | a100_80gb | 55.23 | 73.63 | 40 | 3 | 960 | 16 | 17315 | 2164 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 12848711680 |
-| 13b | 512 | 8 | a100_80gb | 54.99 | 73.32 | 80 | 3 | 1920 | 34 | 17521 | 2190 | 983040 | bf16 | PURE | FULL_SHARD | True | False | 12846090240 |
-| 7b | 65536 | 8 | a100_80gb | 42.61 | 56.82 | 1 | 2 | 16 | 0 | 7355 | 919 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6918905856 |
-| 7b | 32768 | 8 | a100_80gb | 48.18 | 64.24 | 2 | 2 | 32 | 0 | 13035 | 1629 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6784688128 |
-| 7b | 16384 | 8 | a100_80gb | 49.5 | 66.0 | 4 | 2 | 64 | 1 | 18698 | 2337 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6717579264 |
-| 7b | 8192 | 8 | a100_80gb | 50.71 | 67.62 | 8 | 2 | 128 | 2 | 23887 | 2985 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6684024832 |
-| 7b | 4096 | 8 | a100_80gb | 52.05 | 69.4 | 16 | 2 | 256 | 6 | 27973 | 3496 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6667247616 |
-| 7b | 2048 | 64 | a100_80gb | 50.8 | 67.73 | 32 | 1 | 2048 | 114 | 234932 | 3670 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 32 | a100_80gb | 51.16 | 68.22 | 32 | 1 | 1024 | 57 | 118310 | 3697 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 16 | a100_80gb | 51.59 | 68.79 | 32 | 1 | 512 | 29 | 59653 | 3728 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 8 | a100_80gb | 52.92 | 70.56 | 32 | 2 | 512 | 14 | 30596 | 3824 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 1024 | 8 | a100_80gb | 53.66 | 71.55 | 64 | 2 | 1024 | 31 | 32243 | 4030 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6654664704 |
-| 7b | 512 | 8 | a100_80gb | 53.5 | 71.34 | 128 | 2 | 2048 | 64 | 32794 | 4099 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 6652567552 |
-| 3b | 65536 | 8 | a100_80gb | 46.17 | 61.57 | 1 | 2 | 16 | 0 | 14174 | 1771 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 2814366720 |
-| 3b | 32768 | 8 | a100_80gb | 46.73 | 62.31 | 3 | 6 | 144 | 0 | 24003 | 3000 | 4718592 | bf16 | PURE | FULL_SHARD | True | False | 2730480640 |
-| 3b | 16384 | 8 | a100_80gb | 57.29 | 57.29 | 1 | 6 | 48 | 2 | 44356 | 5544 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 2688537600 |
-| 3b | 8192 | 8 | a100_80gb | 58.68 | 58.68 | 3 | 6 | 144 | 7 | 60883 | 7610 | 1179648 | bf16 | PURE | FULL_SHARD | False | False | 2667566080 |
-| 3b | 4096 | 8 | a100_80gb | 59.51 | 59.51 | 5 | 6 | 240 | 18 | 74388 | 9298 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 2657080320 |
-| 3b | 2048 | 64 | a100_80gb | 58.36 | 58.36 | 12 | 3 | 2304 | 317 | 650175 | 10158 | 4718592 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 32 | a100_80gb | 59.22 | 59.22 | 12 | 3 | 1152 | 161 | 329856 | 10308 | 2359296 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 16 | a100_80gb | 59.08 | 59.08 | 10 | 3 | 480 | 80 | 164543 | 10283 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 8 | a100_80gb | 59.77 | 59.77 | 10 | 6 | 480 | 40 | 83230 | 10403 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 1024 | 8 | a100_80gb | 61.56 | 61.56 | 20 | 6 | 960 | 88 | 90906 | 11363 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 2649216000 |
-| 3b | 512 | 8 | a100_80gb | 62.09 | 62.09 | 40 | 6 | 1920 | 184 | 94553 | 11819 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 2647905280 |
-| 1b | 65536 | 8 | a100_80gb | 45.29 | 60.39 | 1 | 2 | 16 | 0 | 23885 | 2985 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 1445974016 |
-| 1b | 32768 | 8 | a100_80gb | 56.02 | 56.02 | 1 | 4 | 32 | 1 | 50657 | 6332 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1378865152 |
-| 1b | 16384 | 8 | a100_80gb | 55.84 | 55.84 | 2 | 4 | 64 | 4 | 78591 | 9823 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1345310720 |
-| 1b | 8192 | 8 | a100_80gb | 56.38 | 56.38 | 3 | 4 | 96 | 13 | 109915 | 13739 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 1328533504 |
-| 1b | 4096 | 8 | a100_80gb | 58.3 | 58.3 | 7 | 4 | 224 | 34 | 140767 | 17595 | 917504 | bf16 | PURE | FULL_SHARD | False | False | 1320144896 |
-| 1b | 2048 | 64 | a100_80gb | 56.67 | 56.67 | 20 | 1 | 1280 | 606 | 1243103 | 19423 | 2621440 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 32 | a100_80gb | 56.74 | 56.74 | 20 | 1 | 640 | 303 | 622285 | 19446 | 1310720 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 16 | a100_80gb | 57.47 | 57.47 | 20 | 1 | 320 | 153 | 315117 | 19694 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 8 | a100_80gb | 59.16 | 59.16 | 14 | 4 | 448 | 79 | 162214 | 20276 | 917504 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 1024 | 8 | a100_80gb | 58.98 | 58.98 | 18 | 4 | 576 | 169 | 173458 | 21682 | 589824 | bf16 | PURE | FULL_SHARD | False | False | 1313853440 |
-| 1b | 512 | 8 | a100_80gb | 60.38 | 60.38 | 56 | 4 | 1792 | 359 | 184268 | 23033 | 917504 | bf16 | PURE | FULL_SHARD | False | False | 1312804864 |
-| 760m | 65536 | 8 | a100_80gb | 45.48 | 60.64 | 1 | 2 | 16 | 0 | 33252 | 4156 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 857988096 |
-| 760m | 32768 | 8 | a100_80gb | 54.48 | 54.48 | 1 | 2 | 16 | 2 | 70305 | 8788 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 807656448 |
-| 760m | 16384 | 8 | a100_80gb | 55.21 | 55.21 | 3 | 2 | 48 | 7 | 115383 | 14422 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 782490624 |
-| 760m | 8192 | 8 | a100_80gb | 55.13 | 55.13 | 6 | 2 | 96 | 20 | 166928 | 20866 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 769907712 |
-| 760m | 4096 | 8 | a100_80gb | 55.2 | 55.2 | 12 | 2 | 192 | 52 | 215501 | 26937 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 763616256 |
-| 760m | 2048 | 64 | a100_80gb | 51.82 | 51.82 | 24 | 1 | 1536 | 923 | 1892166 | 29565 | 3145728 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 32 | a100_80gb | 53.27 | 53.27 | 24 | 1 | 768 | 474 | 972497 | 30390 | 1572864 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 16 | a100_80gb | 53.56 | 53.56 | 24 | 1 | 384 | 238 | 488871 | 30554 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 8 | a100_80gb | 55.67 | 55.67 | 24 | 2 | 384 | 124 | 254104 | 31763 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 1024 | 8 | a100_80gb | 55.98 | 55.98 | 48 | 2 | 768 | 272 | 279108 | 34888 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758897664 |
-| 760m | 512 | 8 | a100_80gb | 56.2 | 56.2 | 96 | 2 | 1536 | 573 | 293755 | 36719 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758111232 |
-| 350m | 65536 | 8 | a100_80gb | 52.39 | 52.39 | 1 | 2 | 16 | 0 | 59835 | 7479 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 420997120 |
-| 350m | 32768 | 8 | a100_80gb | 47.45 | 47.45 | 2 | 2 | 32 | 3 | 98793 | 12349 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 387442688 |
-| 350m | 16384 | 8 | a100_80gb | 53.01 | 53.01 | 4 | 2 | 64 | 11 | 187535 | 23441 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 370665472 |
-| 350m | 8192 | 8 | a100_80gb | 53.21 | 53.21 | 8 | 2 | 128 | 35 | 289398 | 36174 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 362276864 |
-| 350m | 4096 | 8 | a100_80gb | 52.46 | 52.46 | 16 | 2 | 256 | 95 | 390131 | 48766 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 358082560 |
-| 350m | 2048 | 64 | a100_80gb | 47.76 | 47.76 | 32 | 1 | 2048 | 1699 | 3480601 | 54384 | 4194304 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 32 | a100_80gb | 48.58 | 48.58 | 32 | 1 | 1024 | 864 | 1770287 | 55321 | 2097152 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 16 | a100_80gb | 50.53 | 50.53 | 32 | 1 | 512 | 449 | 920605 | 57537 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 8 | a100_80gb | 51.73 | 51.73 | 32 | 2 | 512 | 230 | 471290 | 58911 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 1024 | 8 | a100_80gb | 51.28 | 51.28 | 64 | 2 | 1024 | 514 | 526393 | 65799 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354936832 |
-| 350m | 512 | 8 | a100_80gb | 51.18 | 51.18 | 128 | 2 | 2048 | 1095 | 560858 | 70107 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354412544 |
-| 125m | 65536 | 8 | a100_80gb | 54.31 | 54.31 | 1 | 2 | 16 | 2 | 163472 | 20434 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 174070272 |
-| 125m | 32768 | 8 | a100_80gb | 53.15 | 53.15 | 2 | 2 | 32 | 8 | 293685 | 36710 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 148904448 |
-| 125m | 16384 | 8 | a100_80gb | 51.58 | 51.58 | 4 | 2 | 64 | 29 | 489578 | 61197 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 136321536 |
-| 125m | 8192 | 8 | a100_80gb | 49.18 | 49.18 | 8 | 2 | 128 | 88 | 727986 | 90998 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 130030080 |
-| 125m | 4096 | 8 | a100_80gb | 46.62 | 46.62 | 16 | 2 | 256 | 233 | 958343 | 119792 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 126884352 |
-| 125m | 2048 | 64 | a100_80gb | 40.77 | 40.77 | 32 | 1 | 2048 | 4063 | 8321727 | 130026 | 4194304 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 32 | a100_80gb | 41.22 | 41.22 | 32 | 1 | 1024 | 2053 | 4206041 | 131438 | 2097152 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 16 | a100_80gb | 41.92 | 41.92 | 32 | 1 | 512 | 1044 | 2139036 | 133689 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 8 | a100_80gb | 44.04 | 44.04 | 32 | 2 | 512 | 548 | 1123506 | 140438 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 1024 | 8 | a100_80gb | 43.25 | 43.25 | 64 | 2 | 1024 | 1225 | 1254561 | 156820 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 124525056 |
-| 125m | 512 | 8 | a100_80gb | 42.54 | 42.54 | 128 | 2 | 2048 | 2587 | 1325030 | 165628 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 124131840 |
+| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| 70b | 2048 | 64 | a100_80gb | 53.33 | 71.1 | 166 | 8 | 4 | 2048 | 12 | 26274 | 410 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
+| 70b | 2048 | 32 | a100_80gb | 48.56 | 64.75 | 151 | 2 | 16 | 1024 | 5 | 11962 | 373 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
+| 30b | 8192 | 8 | a100_80gb | 39.38 | 52.5 | 122 | 1 | 21 | 168 | 0 | 4594 | 574 | 1376256 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 30019254272 |
+| 30b | 4096 | 8 | a100_80gb | 51.37 | 68.49 | 160 | 1 | 21 | 168 | 1 | 6513 | 814 | 688128 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29989894144 |
+| 30b | 2048 | 8 | a100_80gb | 55.3 | 73.74 | 172 | 3 | 21 | 504 | 3 | 7330 | 916 | 1032192 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29975214080 |
+| 30b | 1024 | 8 | a100_80gb | 55.82 | 74.43 | 174 | 6 | 21 | 1008 | 7 | 7571 | 946 | 1032192 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29967874048 |
+| 30b | 512 | 8 | a100_80gb | 56.4 | 75.2 | 175 | 12 | 21 | 2016 | 15 | 7739 | 967 | 1032192 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 29964204032 |
+| 13b | 32768 | 8 | a100_80gb | 51.69 | 68.92 | 161 | 1 | 3 | 24 | 0 | 8134 | 1016 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 13011240960 |
+| 13b | 16384 | 8 | a100_80gb | 54.07 | 72.1 | 168 | 3 | 3 | 72 | 0 | 11454 | 1431 | 1179648 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12927354880 |
+| 13b | 8192 | 8 | a100_80gb | 56.07 | 74.76 | 174 | 5 | 3 | 120 | 1 | 14362 | 1795 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12885411840 |
+| 13b | 4096 | 8 | a100_80gb | 57.62 | 76.82 | 179 | 10 | 3 | 240 | 4 | 16482 | 2060 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12864440320 |
+| 13b | 2048 | 8 | a100_80gb | 59.57 | 59.57 | 185 | 2 | 3 | 48 | 8 | 18097 | 2262 | 98304 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 12853954560 |
+| 13b | 1024 | 8 | a100_80gb | 59.48 | 79.3 | 185 | 40 | 3 | 960 | 18 | 18647 | 2330 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 12848711680 |
+| 7b | 65536 | 8 | a100_80gb | 46.97 | 62.63 | 146 | 1 | 2 | 16 | 0 | 8108 | 1013 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6918905856 |
+| 7b | 32768 | 8 | a100_80gb | 49.46 | 65.94 | 154 | 2 | 2 | 32 | 0 | 13382 | 1672 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6784688128 |
+| 7b | 16384 | 8 | a100_80gb | 51.96 | 69.28 | 162 | 4 | 2 | 64 | 1 | 19629 | 2453 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6717579264 |
+| 7b | 8192 | 8 | a100_80gb | 54.47 | 72.62 | 169 | 8 | 2 | 128 | 3 | 25655 | 3206 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6684024832 |
+| 7b | 4096 | 8 | a100_80gb | 54.84 | 73.12 | 171 | 16 | 2 | 256 | 7 | 29472 | 3684 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6667247616 |
+| 7b | 2048 | 8 | a100_80gb | 64.23 | 64.23 | 200 | 6 | 2 | 96 | 18 | 37130 | 4641 | 196608 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 6658859008 |
+| 7b | 1024 | 8 | a100_80gb | 58.01 | 77.35 | 180 | 64 | 2 | 1024 | 34 | 34857 | 4357 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 6654664704 |
+| 3b | 65536 | 8 | a100_80gb | 46.05 | 61.41 | 143 | 1 | 2 | 16 | 0 | 14137 | 1767 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 2814366720 |
+| 3b | 32768 | 8 | a100_80gb | 47.18 | 62.91 | 147 | 3 | 6 | 144 | 0 | 24235 | 3029 | 4718592 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 2730480640 |
+| 3b | 16384 | 8 | a100_80gb | 57.13 | 57.13 | 178 | 1 | 6 | 48 | 2 | 44233 | 5529 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2688537600 |
+| 3b | 8192 | 8 | a100_80gb | 59.34 | 59.34 | 185 | 3 | 6 | 144 | 7 | 61567 | 7695 | 1179648 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2667566080 |
+| 3b | 4096 | 8 | a100_80gb | 60.53 | 60.53 | 188 | 5 | 6 | 240 | 18 | 75658 | 9457 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2657080320 |
+| 3b | 2048 | 8 | a100_80gb | 62.11 | 62.11 | 193 | 10 | 2 | 160 | 42 | 86491 | 10811 | 327680 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2651837440 |
+| 3b | 1024 | 8 | a100_80gb | 62.73 | 62.73 | 195 | 20 | 6 | 960 | 90 | 92643 | 11580 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2649216000 |
+| 3b | 512 | 8 | a100_80gb | 63.71 | 63.71 | 198 | 40 | 6 | 1920 | 189 | 97019 | 12127 | 983040 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 2647905280 |
+| 1b | 65536 | 8 | a100_80gb | 46.18 | 61.57 | 144 | 1 | 2 | 16 | 0 | 24353 | 3044 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 1445974016 |
+| 1b | 32768 | 8 | a100_80gb | 55.52 | 55.52 | 173 | 1 | 4 | 32 | 1 | 50207 | 6275 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1378865152 |
+| 1b | 16384 | 8 | a100_80gb | 56.6 | 56.6 | 176 | 2 | 4 | 64 | 4 | 79650 | 9956 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1345310720 |
+| 1b | 8192 | 8 | a100_80gb | 56.69 | 56.69 | 176 | 3 | 4 | 96 | 13 | 110516 | 13814 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1328533504 |
+| 1b | 4096 | 8 | a100_80gb | 59.0 | 59.0 | 184 | 7 | 4 | 224 | 34 | 142457 | 17807 | 917504 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1320144896 |
+| 1b | 2048 | 8 | a100_80gb | 59.86 | 59.86 | 186 | 14 | 4 | 448 | 80 | 164109 | 20513 | 917504 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1315950592 |
+| 1b | 1024 | 8 | a100_80gb | 60.15 | 60.15 | 187 | 18 | 4 | 576 | 172 | 176898 | 22112 | 589824 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1313853440 |
+| 1b | 512 | 8 | a100_80gb | 60.68 | 60.68 | 189 | 56 | 4 | 1792 | 361 | 185186 | 23148 | 917504 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 1312804864 |
+| 760m | 65536 | 8 | a100_80gb | 45.34 | 60.45 | 141 | 1 | 2 | 16 | 0 | 33150 | 4143 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | True | False | 857988096 |
+| 760m | 32768 | 8 | a100_80gb | 54.57 | 54.57 | 170 | 1 | 2 | 16 | 2 | 70417 | 8802 | 524288 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 807656448 |
+| 760m | 16384 | 8 | a100_80gb | 54.64 | 54.64 | 170 | 3 | 2 | 48 | 6 | 114198 | 14274 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 782490624 |
+| 760m | 8192 | 8 | a100_80gb | 55.31 | 55.31 | 172 | 6 | 2 | 96 | 20 | 167471 | 20933 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 769907712 |
+| 760m | 4096 | 8 | a100_80gb | 56.05 | 56.05 | 174 | 12 | 2 | 192 | 53 | 218808 | 27351 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 763616256 |
+| 760m | 2048 | 8 | a100_80gb | 56.85 | 56.85 | 177 | 24 | 2 | 384 | 126 | 259472 | 32434 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 760470528 |
+| 760m | 1024 | 8 | a100_80gb | 47.76 | 47.76 | 149 | 48 | 2 | 768 | 232 | 238122 | 29765 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 758897664 |
+| 760m | 512 | 8 | a100_80gb | 45.07 | 45.07 | 140 | 96 | 2 | 1536 | 460 | 235571 | 29446 | 786432 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 758111232 |
+| 350m | 65536 | 8 | a100_80gb | 52.7 | 52.7 | 164 | 1 | 2 | 16 | 0 | 60195 | 7524 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 420997120 |
+| 350m | 32768 | 8 | a100_80gb | 52.46 | 52.46 | 163 | 2 | 2 | 32 | 3 | 109222 | 13652 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 387442688 |
+| 350m | 16384 | 8 | a100_80gb | 53.28 | 53.28 | 166 | 4 | 2 | 64 | 11 | 188478 | 23559 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 370665472 |
+| 350m | 8192 | 8 | a100_80gb | 53.8 | 53.8 | 167 | 8 | 2 | 128 | 35 | 292559 | 36569 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 362276864 |
+| 350m | 4096 | 8 | a100_80gb | 53.31 | 53.31 | 166 | 16 | 2 | 256 | 96 | 396442 | 49555 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 358082560 |
+| 350m | 2048 | 8 | a100_80gb | 51.62 | 51.62 | 161 | 32 | 2 | 512 | 229 | 470263 | 58782 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 355985408 |
+| 350m | 1024 | 8 | a100_80gb | 50.51 | 50.51 | 157 | 64 | 2 | 1024 | 506 | 518504 | 64813 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 354936832 |
+| 350m | 512 | 8 | a100_80gb | 50.61 | 50.61 | 157 | 128 | 2 | 2048 | 1083 | 554643 | 69330 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 354412544 |
+| 125m | 65536 | 8 | a100_80gb | 54.13 | 54.13 | 168 | 1 | 2 | 16 | 2 | 162946 | 20368 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 174070272 |
+| 125m | 32768 | 8 | a100_80gb | 52.71 | 52.71 | 164 | 2 | 2 | 32 | 8 | 291256 | 36407 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 148904448 |
+| 125m | 16384 | 8 | a100_80gb | 50.61 | 50.61 | 157 | 4 | 2 | 64 | 29 | 480322 | 60040 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 136321536 |
+| 125m | 8192 | 8 | a100_80gb | 48.85 | 48.85 | 152 | 8 | 2 | 128 | 88 | 723142 | 90392 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 130030080 |
+| 125m | 4096 | 8 | a100_80gb | 46.08 | 46.08 | 143 | 16 | 2 | 256 | 231 | 947172 | 118396 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 126884352 |
+| 125m | 2048 | 8 | a100_80gb | 44.79 | 44.79 | 139 | 40 | 2 | 640 | 557 | 1142641 | 142830 | 1310720 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 125311488 |
+| 125m | 2048 | 8 | a100_80gb | 44.45 | 44.45 | 138 | 32 | 2 | 512 | 553 | 1133901 | 141737 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 125311488 |
+| 125m | 1024 | 8 | a100_80gb | 43.15 | 43.15 | 134 | 64 | 2 | 1024 | 1222 | 1251751 | 156468 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 124525056 |
+| 125m | 512 | 8 | a100_80gb | 42.56 | 42.56 | 132 | 128 | 2 | 2048 | 2588 | 1325455 | 165681 | 1048576 | amp_bf16 | DEFAULT | FULL_SHARD | False | False | 124131840 |
## A100 40GB with 1600 Gbps node-node interconnect (RoCE)
-| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| 70b | 2048 | 128 | a100_40gb | 48.91 | 65.21 | 4 | 1 | 512 | 23 | 48194 | 376 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
-| 70b | 2048 | 64 | a100_40gb | 35.87 | 47.82 | 2 | 1 | 128 | 8 | 17672 | 276 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
-| 30b | 2048 | 128 | a100_40gb | 52.25 | 69.66 | 6 | 1 | 768 | 54 | 110803 | 865 | 1572864 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 30b | 2048 | 32 | a100_40gb | 51.74 | 68.98 | 4 | 1 | 128 | 13 | 27431 | 857 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
-| 13b | 8192 | 8 | a100_40gb | 43.95 | 58.6 | 1 | 16 | 128 | 1 | 11258 | 1407 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12885411840 |
-| 13b | 4096 | 8 | a100_40gb | 44.85 | 59.8 | 2 | 16 | 256 | 3 | 12830 | 1603 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12864440320 |
-| 13b | 2048 | 128 | a100_40gb | 51.93 | 69.24 | 16 | 1 | 2048 | 123 | 252444 | 1972 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 64 | a100_40gb | 52.04 | 69.39 | 16 | 1 | 1024 | 61 | 126479 | 1976 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 32 | a100_40gb | 52.62 | 70.16 | 14 | 1 | 448 | 31 | 63946 | 1998 | 917504 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 16 | a100_40gb | 52.5 | 70.0 | 10 | 1 | 160 | 15 | 31900 | 1993 | 327680 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 2048 | 8 | a100_40gb | 43.94 | 58.58 | 4 | 16 | 512 | 6 | 13347 | 1668 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
-| 13b | 1024 | 8 | a100_40gb | 44.07 | 58.76 | 8 | 16 | 1024 | 13 | 13817 | 1727 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12848711680 |
-| 13b | 512 | 8 | a100_40gb | 44.28 | 59.04 | 16 | 16 | 2048 | 27 | 14108 | 1763 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12846090240 |
-| 7b | 16384 | 8 | a100_40gb | 47.65 | 63.53 | 1 | 4 | 32 | 1 | 17998 | 2249 | 524288 | bf16 | PURE | FULL_SHARD | True | False | 6717579264 |
-| 7b | 8192 | 8 | a100_40gb | 49.04 | 65.38 | 3 | 4 | 96 | 2 | 23098 | 2887 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6684024832 |
-| 7b | 4096 | 8 | a100_40gb | 50.11 | 66.82 | 6 | 4 | 192 | 6 | 26930 | 3366 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6667247616 |
-| 7b | 2048 | 128 | a100_40gb | 50.14 | 66.85 | 18 | 1 | 2304 | 226 | 463749 | 3623 | 4718592 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 64 | a100_40gb | 50.73 | 67.64 | 18 | 1 | 1152 | 114 | 234614 | 3665 | 2359296 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 32 | a100_40gb | 51.55 | 68.73 | 18 | 1 | 576 | 58 | 119202 | 3725 | 1179648 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 16 | a100_40gb | 50.44 | 67.26 | 16 | 1 | 256 | 28 | 58322 | 3645 | 524288 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 2048 | 8 | a100_40gb | 50.92 | 67.89 | 12 | 4 | 384 | 14 | 29436 | 3679 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
-| 7b | 1024 | 8 | a100_40gb | 51.31 | 68.42 | 24 | 4 | 768 | 30 | 30833 | 3854 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6654664704 |
-| 7b | 512 | 8 | a100_40gb | 50.85 | 67.8 | 48 | 4 | 1536 | 60 | 31167 | 3895 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6652567552 |
-| 3b | 32768 | 8 | a100_40gb | 46.03 | 61.37 | 1 | 4 | 32 | 0 | 23640 | 2955 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 2730480640 |
-| 3b | 16384 | 8 | a100_40gb | 46.14 | 61.52 | 2 | 8 | 128 | 2 | 35726 | 4465 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 2688537600 |
-| 3b | 8192 | 8 | a100_40gb | 55.13 | 55.13 | 1 | 8 | 64 | 6 | 57193 | 7149 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 2667566080 |
-| 3b | 4096 | 8 | a100_40gb | 56.18 | 56.18 | 2 | 8 | 128 | 17 | 70223 | 8777 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 2657080320 |
-| 3b | 2048 | 128 | a100_40gb | 54.8 | 54.8 | 6 | 1 | 768 | 596 | 1220885 | 9538 | 1572864 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 64 | a100_40gb | 55.94 | 55.94 | 6 | 1 | 384 | 304 | 623167 | 9736 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 32 | a100_40gb | 56.96 | 56.96 | 6 | 1 | 192 | 154 | 317261 | 9914 | 393216 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 16 | a100_40gb | 56.02 | 56.02 | 5 | 1 | 80 | 76 | 156013 | 9750 | 163840 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 2048 | 8 | a100_40gb | 57.82 | 57.82 | 5 | 8 | 320 | 39 | 80520 | 10065 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
-| 3b | 1024 | 8 | a100_40gb | 58.14 | 58.14 | 10 | 8 | 640 | 83 | 85854 | 10731 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2649216000 |
-| 3b | 512 | 8 | a100_40gb | 59.49 | 59.49 | 20 | 8 | 1280 | 176 | 90596 | 11324 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2647905280 |
-| 1b | 32768 | 8 | a100_40gb | 45.07 | 60.1 | 1 | 4 | 32 | 1 | 40762 | 5095 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 1378865152 |
-| 1b | 16384 | 8 | a100_40gb | 55.23 | 55.23 | 1 | 8 | 64 | 4 | 77723 | 9715 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1345310720 |
-| 1b | 8192 | 8 | a100_40gb | 55.29 | 55.29 | 2 | 8 | 128 | 13 | 107799 | 13474 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1328533504 |
-| 1b | 4096 | 8 | a100_40gb | 55.85 | 55.85 | 4 | 8 | 256 | 32 | 134851 | 16856 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1320144896 |
-| 1b | 2048 | 128 | a100_40gb | 54.41 | 54.41 | 10 | 1 | 1280 | 1165 | 2386897 | 18647 | 2621440 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 64 | a100_40gb | 55.44 | 55.44 | 10 | 1 | 640 | 593 | 1216104 | 19001 | 1310720 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 32 | a100_40gb | 45.39 | 45.39 | 10 | 1 | 320 | 243 | 497782 | 15555 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 16 | a100_40gb | 55.69 | 55.69 | 8 | 1 | 128 | 149 | 305372 | 19085 | 262144 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 2048 | 8 | a100_40gb | 56.23 | 56.23 | 8 | 8 | 512 | 75 | 154171 | 19271 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
-| 1b | 1024 | 8 | a100_40gb | 57.02 | 57.02 | 16 | 8 | 1024 | 163 | 167677 | 20959 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1313853440 |
-| 1b | 512 | 8 | a100_40gb | 57.1 | 57.1 | 32 | 8 | 2048 | 340 | 174256 | 21782 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1312804864 |
-| 760m | 32768 | 8 | a100_40gb | 44.53 | 59.37 | 1 | 4 | 32 | 1 | 57464 | 7183 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 807656448 |
-| 760m | 16384 | 8 | a100_40gb | 53.26 | 53.26 | 1 | 4 | 32 | 6 | 111316 | 13914 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 782490624 |
-| 760m | 8192 | 8 | a100_40gb | 53.12 | 53.12 | 3 | 4 | 96 | 19 | 160853 | 20106 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 769907712 |
-| 760m | 4096 | 8 | a100_40gb | 53.0 | 53.0 | 6 | 4 | 192 | 50 | 206909 | 25863 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 763616256 |
-| 760m | 2048 | 128 | a100_40gb | 50.73 | 50.73 | 12 | 1 | 1536 | 1808 | 3704382 | 28940 | 3145728 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 64 | a100_40gb | 51.44 | 51.44 | 12 | 1 | 768 | 917 | 1878030 | 29344 | 1572864 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 32 | a100_40gb | 51.97 | 51.97 | 12 | 1 | 384 | 463 | 948745 | 29648 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 16 | a100_40gb | 51.9 | 51.9 | 12 | 1 | 192 | 231 | 473723 | 29607 | 393216 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 2048 | 8 | a100_40gb | 52.89 | 52.89 | 12 | 4 | 384 | 117 | 241389 | 30173 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
-| 760m | 1024 | 8 | a100_40gb | 53.63 | 53.63 | 24 | 4 | 768 | 261 | 267376 | 33422 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758897664 |
-| 760m | 512 | 8 | a100_40gb | 53.47 | 53.47 | 48 | 4 | 1536 | 545 | 279504 | 34938 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758111232 |
-| 350m | 32768 | 8 | a100_40gb | 51.55 | 51.55 | 1 | 4 | 32 | 3 | 107329 | 13416 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 387442688 |
-| 350m | 16384 | 8 | a100_40gb | 51.78 | 51.78 | 2 | 4 | 64 | 11 | 183175 | 22896 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 370665472 |
-| 350m | 8192 | 8 | a100_40gb | 51.39 | 51.39 | 4 | 4 | 128 | 34 | 279466 | 34933 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 362276864 |
-| 350m | 4096 | 8 | a100_40gb | 50.38 | 50.38 | 8 | 4 | 256 | 91 | 374670 | 46833 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 358082560 |
-| 350m | 2048 | 128 | a100_40gb | 45.61 | 45.61 | 18 | 1 | 2304 | 3245 | 6647647 | 51934 | 4718592 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 64 | a100_40gb | 46.27 | 46.27 | 18 | 1 | 1152 | 1646 | 3372118 | 52689 | 2359296 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 32 | a100_40gb | 47.26 | 47.26 | 18 | 1 | 576 | 840 | 1721978 | 53811 | 1179648 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 16 | a100_40gb | 48.66 | 48.66 | 18 | 1 | 288 | 432 | 886622 | 55413 | 589824 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 2048 | 8 | a100_40gb | 49.17 | 49.17 | 16 | 4 | 512 | 218 | 447963 | 55995 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
-| 350m | 1024 | 8 | a100_40gb | 48.73 | 48.73 | 32 | 4 | 1024 | 488 | 500184 | 62523 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354936832 |
-| 350m | 512 | 8 | a100_40gb | 48.39 | 48.39 | 64 | 4 | 2048 | 1035 | 530277 | 66284 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354412544 |
-| 125m | 32768 | 8 | a100_40gb | 47.27 | 47.27 | 1 | 4 | 32 | 7 | 261208 | 32651 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 148904448 |
-| 125m | 16384 | 8 | a100_40gb | 46.77 | 46.77 | 2 | 3 | 48 | 27 | 443876 | 55484 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 136321536 |
-| 125m | 8192 | 8 | a100_40gb | 46.94 | 46.94 | 5 | 3 | 120 | 84 | 694868 | 86858 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 130030080 |
-| 125m | 4096 | 8 | a100_40gb | 44.82 | 44.82 | 13 | 3 | 312 | 224 | 921297 | 115162 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 126884352 |
-| 125m | 2048 | 128 | a100_40gb | 38.86 | 38.86 | 26 | 1 | 3328 | 7746 | 15863837 | 123936 | 6815744 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 64 | a100_40gb | 39.27 | 39.27 | 26 | 1 | 1664 | 3913 | 8015010 | 125234 | 3407872 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 32 | a100_40gb | 39.86 | 39.86 | 26 | 1 | 832 | 1986 | 4067922 | 127122 | 1703936 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 16 | a100_40gb | 40.93 | 40.93 | 26 | 1 | 416 | 1019 | 2088560 | 130535 | 851968 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 2048 | 8 | a100_40gb | 42.75 | 42.75 | 26 | 3 | 624 | 532 | 1090678 | 136334 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
-| 125m | 1024 | 8 | a100_40gb | 40.89 | 40.89 | 52 | 3 | 1248 | 1158 | 1186314 | 148289 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 124525056 |
-| 125m | 512 | 8 | a100_40gb | 40.26 | 40.26 | 104 | 3 | 2496 | 2448 | 1253886 | 156735 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 124131840 |
+| Model | SeqLen (T) | # GPUs | GPU | MFU | HFU | Model TFLOP| MicroBatchSize | GradAccum | GlobalBatchSize | Throughput (S/s) | Throughput (T/s) | Throughput (T/s/GPU) | GlobalBatchSize (T) | Precision | MP Mode | Sharding Strategy | Activation Checkpointing | Activation CPUOffload | NumParams |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| 70b | 2048 | 128 | a100_40gb | 48.91 | 65.21 | 152 | 4 | 1 | 512 | 23 | 48194 | 376 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
+| 70b | 2048 | 64 | a100_40gb | 35.87 | 47.82 | 111 | 2 | 1 | 128 | 8 | 17672 | 276 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 64862437376 |
+| 30b | 2048 | 128 | a100_40gb | 52.25 | 69.66 | 163 | 6 | 1 | 768 | 54 | 110803 | 865 | 1572864 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
+| 30b | 2048 | 32 | a100_40gb | 51.74 | 68.98 | 161 | 4 | 1 | 128 | 13 | 27431 | 857 | 262144 | bf16 | PURE | FULL_SHARD | True | False | 29975214080 |
+| 13b | 8192 | 8 | a100_40gb | 43.95 | 58.6 | 137 | 1 | 16 | 128 | 1 | 11258 | 1407 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12885411840 |
+| 13b | 4096 | 8 | a100_40gb | 44.85 | 59.8 | 139 | 2 | 16 | 256 | 3 | 12830 | 1603 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12864440320 |
+| 13b | 2048 | 128 | a100_40gb | 51.93 | 69.24 | 162 | 16 | 1 | 2048 | 123 | 252444 | 1972 | 4194304 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
+| 13b | 2048 | 64 | a100_40gb | 52.04 | 69.39 | 162 | 16 | 1 | 1024 | 61 | 126479 | 1976 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
+| 13b | 2048 | 32 | a100_40gb | 52.62 | 70.16 | 164 | 14 | 1 | 448 | 31 | 63946 | 1998 | 917504 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
+| 13b | 2048 | 16 | a100_40gb | 52.5 | 70.0 | 163 | 10 | 1 | 160 | 15 | 31900 | 1993 | 327680 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
+| 13b | 2048 | 8 | a100_40gb | 43.94 | 58.58 | 137 | 4 | 16 | 512 | 6 | 13347 | 1668 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12853954560 |
+| 13b | 1024 | 8 | a100_40gb | 44.07 | 58.76 | 137 | 8 | 16 | 1024 | 13 | 13817 | 1727 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12848711680 |
+| 13b | 512 | 8 | a100_40gb | 44.28 | 59.04 | 138 | 16 | 16 | 2048 | 27 | 14108 | 1763 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 12846090240 |
+| 7b | 16384 | 8 | a100_40gb | 47.65 | 63.53 | 148 | 1 | 4 | 32 | 1 | 17998 | 2249 | 524288 | bf16 | PURE | FULL_SHARD | True | False | 6717579264 |
+| 7b | 8192 | 8 | a100_40gb | 49.04 | 65.38 | 153 | 3 | 4 | 96 | 2 | 23098 | 2887 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6684024832 |
+| 7b | 4096 | 8 | a100_40gb | 50.11 | 66.82 | 156 | 6 | 4 | 192 | 6 | 26930 | 3366 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6667247616 |
+| 7b | 2048 | 128 | a100_40gb | 50.14 | 66.85 | 156 | 18 | 1 | 2304 | 226 | 463749 | 3623 | 4718592 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
+| 7b | 2048 | 64 | a100_40gb | 50.73 | 67.64 | 158 | 18 | 1 | 1152 | 114 | 234614 | 3665 | 2359296 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
+| 7b | 2048 | 32 | a100_40gb | 51.55 | 68.73 | 160 | 18 | 1 | 576 | 58 | 119202 | 3725 | 1179648 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
+| 7b | 2048 | 16 | a100_40gb | 50.44 | 67.26 | 157 | 16 | 1 | 256 | 28 | 58322 | 3645 | 524288 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
+| 7b | 2048 | 8 | a100_40gb | 50.92 | 67.89 | 158 | 12 | 4 | 384 | 14 | 29436 | 3679 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6658859008 |
+| 7b | 1024 | 8 | a100_40gb | 51.31 | 68.42 | 160 | 24 | 4 | 768 | 30 | 30833 | 3854 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6654664704 |
+| 7b | 512 | 8 | a100_40gb | 50.85 | 67.8 | 158 | 48 | 4 | 1536 | 60 | 31167 | 3895 | 786432 | bf16 | PURE | FULL_SHARD | True | False | 6652567552 |
+| 3b | 32768 | 8 | a100_40gb | 46.03 | 61.37 | 143 | 1 | 4 | 32 | 0 | 23640 | 2955 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 2730480640 |
+| 3b | 16384 | 8 | a100_40gb | 46.14 | 61.52 | 143 | 2 | 8 | 128 | 2 | 35726 | 4465 | 2097152 | bf16 | PURE | FULL_SHARD | True | False | 2688537600 |
+| 3b | 8192 | 8 | a100_40gb | 55.13 | 55.13 | 172 | 1 | 8 | 64 | 6 | 57193 | 7149 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 2667566080 |
+| 3b | 4096 | 8 | a100_40gb | 56.18 | 56.18 | 175 | 2 | 8 | 128 | 17 | 70223 | 8777 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 2657080320 |
+| 3b | 2048 | 128 | a100_40gb | 54.8 | 54.8 | 170 | 6 | 1 | 768 | 596 | 1220885 | 9538 | 1572864 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 64 | a100_40gb | 55.94 | 55.94 | 174 | 6 | 1 | 384 | 304 | 623167 | 9736 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 32 | a100_40gb | 56.96 | 56.96 | 177 | 6 | 1 | 192 | 154 | 317261 | 9914 | 393216 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 16 | a100_40gb | 56.02 | 56.02 | 174 | 5 | 1 | 80 | 76 | 156013 | 9750 | 163840 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
+| 3b | 2048 | 8 | a100_40gb | 57.82 | 57.82 | 180 | 5 | 8 | 320 | 39 | 80520 | 10065 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2651837440 |
+| 3b | 1024 | 8 | a100_40gb | 58.14 | 58.14 | 181 | 10 | 8 | 640 | 83 | 85854 | 10731 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2649216000 |
+| 3b | 512 | 8 | a100_40gb | 59.49 | 59.49 | 185 | 20 | 8 | 1280 | 176 | 90596 | 11324 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 2647905280 |
+| 1b | 32768 | 8 | a100_40gb | 45.07 | 60.1 | 140 | 1 | 4 | 32 | 1 | 40762 | 5095 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 1378865152 |
+| 1b | 16384 | 8 | a100_40gb | 55.23 | 55.23 | 172 | 1 | 8 | 64 | 4 | 77723 | 9715 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1345310720 |
+| 1b | 8192 | 8 | a100_40gb | 55.29 | 55.29 | 172 | 2 | 8 | 128 | 13 | 107799 | 13474 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1328533504 |
+| 1b | 4096 | 8 | a100_40gb | 55.85 | 55.85 | 174 | 4 | 8 | 256 | 32 | 134851 | 16856 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1320144896 |
+| 1b | 2048 | 128 | a100_40gb | 54.41 | 54.41 | 169 | 10 | 1 | 1280 | 1165 | 2386897 | 18647 | 2621440 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 64 | a100_40gb | 55.44 | 55.44 | 172 | 10 | 1 | 640 | 593 | 1216104 | 19001 | 1310720 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 32 | a100_40gb | 45.39 | 45.39 | 141 | 10 | 1 | 320 | 243 | 497782 | 15555 | 655360 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 16 | a100_40gb | 55.69 | 55.69 | 173 | 8 | 1 | 128 | 149 | 305372 | 19085 | 262144 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
+| 1b | 2048 | 8 | a100_40gb | 56.23 | 56.23 | 175 | 8 | 8 | 512 | 75 | 154171 | 19271 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1315950592 |
+| 1b | 1024 | 8 | a100_40gb | 57.02 | 57.02 | 177 | 16 | 8 | 1024 | 163 | 167677 | 20959 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1313853440 |
+| 1b | 512 | 8 | a100_40gb | 57.1 | 57.1 | 178 | 32 | 8 | 2048 | 340 | 174256 | 21782 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 1312804864 |
+| 760m | 32768 | 8 | a100_40gb | 44.53 | 59.37 | 138 | 1 | 4 | 32 | 1 | 57464 | 7183 | 1048576 | bf16 | PURE | FULL_SHARD | True | False | 807656448 |
+| 760m | 16384 | 8 | a100_40gb | 53.26 | 53.26 | 166 | 1 | 4 | 32 | 6 | 111316 | 13914 | 524288 | bf16 | PURE | FULL_SHARD | False | False | 782490624 |
+| 760m | 8192 | 8 | a100_40gb | 53.12 | 53.12 | 165 | 3 | 4 | 96 | 19 | 160853 | 20106 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 769907712 |
+| 760m | 4096 | 8 | a100_40gb | 53.0 | 53.0 | 165 | 6 | 4 | 192 | 50 | 206909 | 25863 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 763616256 |
+| 760m | 2048 | 128 | a100_40gb | 50.73 | 50.73 | 158 | 12 | 1 | 1536 | 1808 | 3704382 | 28940 | 3145728 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 64 | a100_40gb | 51.44 | 51.44 | 160 | 12 | 1 | 768 | 917 | 1878030 | 29344 | 1572864 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 32 | a100_40gb | 51.97 | 51.97 | 162 | 12 | 1 | 384 | 463 | 948745 | 29648 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 16 | a100_40gb | 51.9 | 51.9 | 161 | 12 | 1 | 192 | 231 | 473723 | 29607 | 393216 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
+| 760m | 2048 | 8 | a100_40gb | 52.89 | 52.89 | 165 | 12 | 4 | 384 | 117 | 241389 | 30173 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 760470528 |
+| 760m | 1024 | 8 | a100_40gb | 53.63 | 53.63 | 167 | 24 | 4 | 768 | 261 | 267376 | 33422 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758897664 |
+| 760m | 512 | 8 | a100_40gb | 53.47 | 53.47 | 166 | 48 | 4 | 1536 | 545 | 279504 | 34938 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 758111232 |
+| 350m | 32768 | 8 | a100_40gb | 51.55 | 51.55 | 160 | 1 | 4 | 32 | 3 | 107329 | 13416 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 387442688 |
+| 350m | 16384 | 8 | a100_40gb | 51.78 | 51.78 | 161 | 2 | 4 | 64 | 11 | 183175 | 22896 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 370665472 |
+| 350m | 8192 | 8 | a100_40gb | 51.39 | 51.39 | 160 | 4 | 4 | 128 | 34 | 279466 | 34933 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 362276864 |
+| 350m | 4096 | 8 | a100_40gb | 50.38 | 50.38 | 157 | 8 | 4 | 256 | 91 | 374670 | 46833 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 358082560 |
+| 350m | 2048 | 128 | a100_40gb | 45.61 | 45.61 | 142 | 18 | 1 | 2304 | 3245 | 6647647 | 51934 | 4718592 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
+| 350m | 2048 | 64 | a100_40gb | 46.27 | 46.27 | 144 | 18 | 1 | 1152 | 1646 | 3372118 | 52689 | 2359296 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
+| 350m | 2048 | 32 | a100_40gb | 47.26 | 47.26 | 147 | 18 | 1 | 576 | 840 | 1721978 | 53811 | 1179648 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
+| 350m | 2048 | 16 | a100_40gb | 48.66 | 48.66 | 151 | 18 | 1 | 288 | 432 | 886622 | 55413 | 589824 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
+| 350m | 2048 | 8 | a100_40gb | 49.17 | 49.17 | 153 | 16 | 4 | 512 | 218 | 447963 | 55995 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 355985408 |
+| 350m | 1024 | 8 | a100_40gb | 48.73 | 48.73 | 152 | 32 | 4 | 1024 | 488 | 500184 | 62523 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354936832 |
+| 350m | 512 | 8 | a100_40gb | 48.39 | 48.39 | 150 | 64 | 4 | 2048 | 1035 | 530277 | 66284 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 354412544 |
+| 125m | 32768 | 8 | a100_40gb | 47.27 | 47.27 | 147 | 1 | 4 | 32 | 7 | 261208 | 32651 | 1048576 | bf16 | PURE | FULL_SHARD | False | False | 148904448 |
+| 125m | 16384 | 8 | a100_40gb | 46.77 | 46.77 | 145 | 2 | 3 | 48 | 27 | 443876 | 55484 | 786432 | bf16 | PURE | FULL_SHARD | False | False | 136321536 |
+| 125m | 8192 | 8 | a100_40gb | 46.94 | 46.94 | 146 | 5 | 3 | 120 | 84 | 694868 | 86858 | 983040 | bf16 | PURE | FULL_SHARD | False | False | 130030080 |
+| 125m | 4096 | 8 | a100_40gb | 44.82 | 44.82 | 139 | 13 | 3 | 312 | 224 | 921297 | 115162 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 126884352 |
+| 125m | 2048 | 128 | a100_40gb | 38.86 | 38.86 | 121 | 26 | 1 | 3328 | 7746 | 15863837 | 123936 | 6815744 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
+| 125m | 2048 | 64 | a100_40gb | 39.27 | 39.27 | 122 | 26 | 1 | 1664 | 3913 | 8015010 | 125234 | 3407872 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
+| 125m | 2048 | 32 | a100_40gb | 39.86 | 39.86 | 124 | 26 | 1 | 832 | 1986 | 4067922 | 127122 | 1703936 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
+| 125m | 2048 | 16 | a100_40gb | 40.93 | 40.93 | 127 | 26 | 1 | 416 | 1019 | 2088560 | 130535 | 851968 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
+| 125m | 2048 | 8 | a100_40gb | 42.75 | 42.75 | 133 | 26 | 3 | 624 | 532 | 1090678 | 136334 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 125311488 |
+| 125m | 1024 | 8 | a100_40gb | 40.89 | 40.89 | 127 | 52 | 3 | 1248 | 1158 | 1186314 | 148289 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 124525056 |
+| 125m | 512 | 8 | a100_40gb | 40.26 | 40.26 | 125 | 104 | 3 | 2496 | 2448 | 1253886 | 156735 | 1277952 | bf16 | PURE | FULL_SHARD | False | False | 124131840 |
diff --git a/scripts/train/benchmarking/collect_results.py b/scripts/train/benchmarking/collect_results.py
index 050390b743..d3691e951c 100644
--- a/scripts/train/benchmarking/collect_results.py
+++ b/scripts/train/benchmarking/collect_results.py
@@ -6,9 +6,10 @@
import math
from typing import Any, Dict, List, Union
-from mcli import sdk as msdk
+from composer.callbacks.speed_monitor import \
+ GPU_AVAILABLE_FLOPS as GPU_FLOP_DICT
-GPU_AVAILABLE_FLOPS = 312_000_000_000_000
+from mcli import sdk as msdk
def str_to_bool(value: Union[bool, str]):
@@ -46,13 +47,19 @@ def parse_args():
def get_runs(args: argparse.Namespace):
- runs = [r for r in msdk.get_runs() if args.project in r.name]
+ runs = [
+ r for r in msdk.get_runs(include_details=True)
+ if args.project in r.name.split('-')[0] and
+ r.status == msdk.RunStatus('COMPLETED')
+ ]
for filter in args.filters:
runs = [r for r in runs if filter in r.name]
def sort_key(r: msdk.Run):
model_name = r.name.split('-')[2]
- num_gpu = r.config.gpu_num
+ num_gpu = r.gpus
+ gpu_type = r.gpu_type
+ model_precision = r.submitted_config.parameters['precision']
if model_name[-1] == 'm':
model_name_size = 1e6
elif model_name[-1] == 'b':
@@ -61,9 +68,12 @@ def sort_key(r: msdk.Run):
print(model_name)
raise ValueError
model_size = int(model_name[:-1])
- return (model_name_size, model_size, r.config.parameters['max_seq_len'],
- num_gpu, r.config.parameters['global_train_batch_size'])
+ return (gpu_type, model_precision, model_name_size, model_size,
+ r.submitted_config.parameters['max_seq_len'], num_gpu,
+ r.submitted_config.parameters['global_train_batch_size'])
+ unique_runs = {sort_key(i): i for i in runs}
+ runs = [unique_runs[r] for r in unique_runs]
runs.sort(reverse=True, key=sort_key)
return runs
@@ -83,17 +93,7 @@ def filter_runs(runs: List[msdk.Run]):
pop_runs = []
for run in runs:
- if run.status in [
- msdk.RunStatus('FAILED_PULL'),
- msdk.RunStatus('PENDING'),
- msdk.RunStatus('QUEUED'),
- msdk.RunStatus('RUNNING'),
- msdk.RunStatus('SCHEDULED'),
- msdk.RunStatus('STARTING'),
- msdk.RunStatus('STOPPED'),
- msdk.RunStatus('STOPPING'),
- msdk.RunStatus('TERMINATING'),
- ]:
+ if run.status != msdk.RunStatus('COMPLETED'):
print(f'run {run.name} has run status {run.status}')
pop_runs.append(run)
for run in pop_runs:
@@ -106,13 +106,22 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]:
n_params = micro_batchsize = throughput = -1
model_name = run.name.split('-')[2]
- gpu_num = run.config.gpu_num
- gpu_type = run.config.gpu_type
-
- fsdp_config = run.config.parameters['fsdp_config']
-
- seq_len = run.config.parameters['max_seq_len']
- global_train_batch_size = run.config.parameters['global_train_batch_size']
+ gpus = run.gpus
+ gpu_type = run.gpu_type
+
+ if 'h100' in gpu_type:
+ gpu_type = 'h100-sxm'
+ if 'a100' in gpu_type:
+ gpu_type = 'a100'
+ GPU_AVAILABLE_FLOPS = GPU_FLOP_DICT[gpu_type][
+ run.submitted_config.parameters['precision']]
+
+ gpu_type = run.gpu_type
+ fsdp_config = run.submitted_config.parameters['fsdp_config']
+
+ seq_len = run.submitted_config.parameters['max_seq_len']
+ global_train_batch_size = run.submitted_config.parameters[
+ 'global_train_batch_size']
activation_checkpointing = fsdp_config['activation_checkpointing']
logs = msdk.get_run_logs(run)
@@ -138,8 +147,8 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]:
throughput = float(line.split(' ')[-1])
break
- d_model = run.config.parameters['model']['d_model']
- n_layers = run.config.parameters['model']['n_layers']
+ d_model = run.submitted_config.parameters['model']['d_model']
+ n_layers = run.submitted_config.parameters['model']['n_layers']
# mfu is approximated using thoughtput and param count
# the number of paramters is approximately the number of multiply-accumulates (MAC) in the network
@@ -153,31 +162,36 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]:
attn_flops_per_seq = n_layers * 2 * 2 * (d_model * (seq_len**2))
# there are 2 ops in bwd pass and 1 in fwd pass so we mult by 3
mfu_w_attn = (3 * flops_per_seq + 3 * attn_flops_per_seq) * throughput / (
- gpu_num * GPU_AVAILABLE_FLOPS)
+ gpus * GPU_AVAILABLE_FLOPS)
if activation_checkpointing:
hfu_w_attn = (4 * flops_per_seq + 4 * attn_flops_per_seq
- ) * throughput / (gpu_num * GPU_AVAILABLE_FLOPS)
+ ) * throughput / (gpus * GPU_AVAILABLE_FLOPS)
else:
hfu_w_attn = mfu_w_attn
+ model_tflop = int(
+ (3 * flops_per_seq + 3 * attn_flops_per_seq) * throughput / gpus / 1e12)
+
return {
'Model':
model_name,
'SeqLen (T)':
seq_len,
'# GPUs':
- gpu_num,
+ gpus,
'GPU':
gpu_type,
'MFU':
round(mfu_w_attn * 100, 2),
'HFU':
round(hfu_w_attn * 100, 2),
+ 'Model TFLOP':
+ model_tflop,
'MicroBatchSize':
micro_batchsize,
'GradAccum':
- math.ceil(global_train_batch_size / gpu_num / micro_batchsize),
+ math.ceil(global_train_batch_size / gpus / micro_batchsize),
'GlobalBatchSize':
global_train_batch_size,
'Throughput (S/s)':
@@ -185,11 +199,11 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]:
'Throughput (T/s)':
int(throughput * seq_len),
'Throughput (T/s/GPU)':
- int(throughput * seq_len / gpu_num),
+ int(throughput * seq_len / gpus),
'GlobalBatchSize (T)':
global_train_batch_size * seq_len,
'Precision':
- run.config.parameters['precision'],
+ run.submitted_config.parameters['precision'],
'MP Mode':
fsdp_config['mixed_precision'],
'Sharding Strategy':
diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py
index f7db0613ef..6530e79b0b 100644
--- a/scripts/train/benchmarking/submit_benchmarks.py
+++ b/scripts/train/benchmarking/submit_benchmarks.py
@@ -62,7 +62,7 @@ def parse_args():
type=str,
default=['bf16'],
nargs='+',
- choices=['bf16', 'fp16'])
+ choices=['bf16', 'fp16', 'fp8'])
parser.add_argument('--fsdp_config_mixed_precision',
type=str,
default='PURE')
@@ -71,6 +71,31 @@ def parse_args():
nargs='?',
const=True,
default=None)
+ parser.add_argument('--fsdp_config_shard_strategy',
+ type=str,
+ nargs='?',
+ const=True,
+ default=None)
+ parser.add_argument('--fsdp_config_limit_all_gathers',
+ type=str_to_bool,
+ nargs='?',
+ const=True,
+ default=None)
+ parser.add_argument('--fsdp_config_forward_prefetch',
+ type=str_to_bool,
+ nargs='?',
+ const=True,
+ default=None)
+ parser.add_argument('--fsdp_config_backward_prefetch',
+ type=str,
+ nargs='?',
+ const=True,
+ default=None)
+ parser.add_argument('--activation_cpu_offload',
+ type=str_to_bool,
+ nargs='?',
+ const=True,
+ default=None)
parser.add_argument(
'-s',
'--seq_len_exp',
@@ -121,7 +146,7 @@ def parse_args():
parser.add_argument('-c',
'--clusters',
type=str,
- default=['r7z2'],
+ default=['r1z1'],
nargs='+',
choices=CLUSTER_INFO.keys())
known_args = parser.parse_known_args()[0]
@@ -136,7 +161,7 @@ def parse_args():
parser.add_argument('-g',
'--gpu_nums',
type=int,
- default=[16],
+ default=[8],
nargs='+',
choices=_gpu_nums)
@@ -158,14 +183,13 @@ def parse_args():
const=True,
default=True)
- parser.add_argument('--priority', type=str, default='low')
+ parser.add_argument('--priority', type=str, default='lowest')
parser.add_argument('--RUN',
type=str_to_bool,
nargs='?',
const=True,
default=False)
-
return parser.parse_args()
@@ -236,19 +260,26 @@ def get_valid_gpu_lim(cluster: str, gpu_type: str):
raise ValueError
-def mod_parameters(parameters: Dict[str, Any],
- max_seq_len: int,
- global_train_batch_size: int,
- precision: str,
- fsdp_config_mixed_precision: str = 'DEFAULT',
- fsdp_config_activation_checkpointing: Optional[bool] = None,
- run_name: str = '',
- data_remote: Optional[str] = None,
- max_duration: str = '30ba',
- eval_interval: int = 0,
- microbatch_size: Optional[Union[int, str]] = None,
- wandb: bool = True,
- pad_vocab_multiple: Optional[int] = None):
+def mod_parameters(
+ parameters: Dict[str, Any],
+ max_seq_len: int,
+ global_train_batch_size: int,
+ precision: str,
+ fsdp_config_mixed_precision: str = 'DEFAULT',
+ fsdp_config_activation_checkpointing: Optional[bool] = None,
+ fsdp_config_shard_strategy: Optional[str] = None,
+ fsdp_config_forward_prefetch: Optional[bool] = None,
+ fsdp_config_backward_prefetch: Optional[str] = None,
+ fsdp_config_limit_all_gathers: Optional[bool] = None,
+ activation_cpu_offload: Optional[bool] = None,
+ run_name: str = '',
+ data_remote: Optional[str] = None,
+ max_duration: str = '30ba',
+ eval_interval: int = 0,
+ microbatch_size: Optional[Union[int, str]] = None,
+ wandb: bool = True,
+ pad_vocab_multiple: Optional[int] = None,
+):
if run_name:
parameters['run_name'] = run_name
if data_remote is not None:
@@ -271,9 +302,9 @@ def mod_parameters(parameters: Dict[str, Any],
parameters['max_seq_len'] = max_seq_len
parameters['model']['max_seq_len'] = max_seq_len
- parameters['model']['attn_impl'] = args.attn_impl
+ parameters['model']['attn_config']['attn_impl'] = args.attn_impl
- parameters['model']['low_precision_layernorm'] = True
+ parameters['model']['norm_type'] = 'low_precision_layernorm'
# Pad vocab size to multiple of N for A100 perf
if pad_vocab_multiple:
@@ -305,9 +336,21 @@ def mod_parameters(parameters: Dict[str, Any],
if fsdp_config_activation_checkpointing is not None:
parameters['fsdp_config'][
'activation_checkpointing'] = fsdp_config_activation_checkpointing
-
- parameters['fsdp_config']['activation_checkpointing_reentrant'] = False
- parameters['fsdp_config']['limit_all_gathers'] = True
+ if fsdp_config_shard_strategy is not None:
+ parameters['fsdp_config'][
+ 'sharding_strategy'] = fsdp_config_shard_strategy
+ if fsdp_config_limit_all_gathers is not None:
+ parameters['fsdp_config'][
+ 'limit_all_gathers'] = fsdp_config_limit_all_gathers
+ if fsdp_config_forward_prefetch is not None:
+ parameters['fsdp_config'][
+ 'forward_prefetch'] = fsdp_config_forward_prefetch
+ if fsdp_config_backward_prefetch is not None:
+ parameters['fsdp_config'][
+ 'backward_prefetch'] = fsdp_config_backward_prefetch
+ if activation_cpu_offload is not None:
+ parameters['fsdp_config'][
+ 'activation_cpu_offload'] = activation_cpu_offload
if wandb:
# add wandb
@@ -332,7 +375,7 @@ def get_integrations(project: str,
}
git_integration.update({
'integration_type': 'git_repo',
- 'git_repo': 'mosaicml/examples',
+ 'git_repo': 'mosaicml/llm-foundry',
'pip_install': '-e .[gpu]'
})
@@ -351,30 +394,42 @@ def get_integrations(project: str,
def run_config(config: Tuple[str, int, int, str, str, int, str],
args: argparse.Namespace):
model_yaml, max_seq_len, global_train_batch_size, cluster, gpu_type, gpu_num, precision = config
-
- integrations = get_integrations(
- args.project,
- git_branch=args.git_branch,
- git_commit=args.git_commit,
- wandb=args.wandb) # point to git repo and potentially wandb
-
- # Define our command
- if args.data_remote is not None:
- command = """
- cd examples/scripts
-
- composer train/train.py /mnt/config/parameters.yaml
+ integrations = [
+ {
+ 'integration_type': 'git_repo',
+ 'git_repo': 'mosaicml/llm-foundry',
+ 'git_branch': 'v0.3.0',
+ 'pip_install': '-e .[gpu]',
+ },
+ {
+ 'integration_type': 'wandb',
+ 'entity': 'mosaic-ml',
+ 'project': args.project
+ },
+ ]
+
+ command = ''
+ if gpu_type == 'h100_80gb' and 'fp8' in precision: # Required for flash-attn and FP8 training
+ command += f"""
+ pip install flash-attn==1.0.7 --no-build-isolation
+ pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.10
+ pip uninstall install pydantic --yes
+ pip install pydantic==1.9.0
"""
+
+ if args.data_remote is None:
+ command += f"""
+ cd llm-foundry/scripts
+ python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --tokenizer gpt2 --eos_text '<|endoftext|>'
+ composer train/train.py /mnt/config/parameters.yaml
+ """
else:
command = f"""
- cd examples/scripts
-
- python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --tokenizer gpt2 --eos_text '<|endoftext|>'
-
- composer train/train.py /mnt/config/parameters.yaml
- """
+ cd llm-foundry/scripts
+ composer train/train.py /mnt/config/parameters.yaml
+ """
- path = os.path.join('../yamls/mpt', model_yaml)
+ path = os.path.join('../yamls/pretrain', 'mpt-' + model_yaml)
parameters = get_parameters(path)
model_name = '-'.join(model_yaml.split('.')[-2].split('/')[-2:]).replace(
@@ -391,23 +446,28 @@ def run_config(config: Tuple[str, int, int, str, str, int, str],
_name = name
name = name[:name_len_lim]
print(f'Shortening {_name} to {name} ({name_len_lim} chars)')
-
microbatch_size = args.microbatch_size or 'auto'
assert isinstance(microbatch_size, (int, str))
parameters = mod_parameters(
parameters,
max_seq_len,
global_train_batch_size,
- precision,
- fsdp_config_mixed_precision=args.fsdp_config_mixed_precision,
+ 'amp_' + precision,
fsdp_config_activation_checkpointing=args.
fsdp_config_activation_checkpointing,
+ fsdp_config_limit_all_gathers=args.fsdp_config_limit_all_gathers,
+ fsdp_config_shard_strategy=args.fsdp_config_shard_strategy,
+ fsdp_config_forward_prefetch=args.fsdp_config_forward_prefetch,
+ fsdp_config_backward_prefetch=args.fsdp_config_backward_prefetch,
+ activation_cpu_offload=args.activation_cpu_offload,
run_name=name,
data_remote=args.data_remote,
microbatch_size=microbatch_size,
wandb=args.wandb,
- pad_vocab_multiple=args.pad_vocab_multiple)
-
+ pad_vocab_multiple=args.pad_vocab_multiple,
+ )
+ if gpu_type == 'h100_80gb' and precision == 'fp8':
+ parameters['model']['fc_type'] = 'te'
# Create run config mcli sdk/api
config = RunConfig(name=name,
gpu_type=gpu_type,
@@ -417,8 +477,8 @@ def run_config(config: Tuple[str, int, int, str, str, int, str],
integrations=integrations,
command=command,
parameters=parameters,
- scheduling=SchedulingConfig(priority=args.priority))
-
+ scheduling=SchedulingConfig(priority=args.priority,
+ resumable=True))
if args.RUN:
# Create the run from a config
run = create_run(config)
@@ -461,7 +521,6 @@ def run_check_dtms(num_gpus: int, dtms: int, batch_size: int):
if __name__ == '__main__':
args = parse_args()
-
n_jobs = 0
for max_seq_len in get_max_seq_lens(args.seq_len_exp):
for cluster in args.clusters:
@@ -497,7 +556,6 @@ def run_check_dtms(num_gpus: int, dtms: int, batch_size: int):
global_train_batch_size,
cluster, gpu_type,
gpu_num, precision)
- print(config)
run_config(config, args)
n_jobs += 1
diff --git a/scripts/train/benchmarking/sweep.sh b/scripts/train/benchmarking/sweep.sh
index 5d962b7c5c..97372ee6fd 100755
--- a/scripts/train/benchmarking/sweep.sh
+++ b/scripts/train/benchmarking/sweep.sh
@@ -2,34 +2,148 @@
PROJECT="tput"
GIT_COMMIT="v0.0.4"
-IMAGE="mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04"
-CLUSTER_80GB=YOUR_CLUSTER_80GB
-CLUSTER_40GB=YOUR_CLUSTER_40GB
+IMAGE="mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04"
+CLUSTER_40GB= # TODO
+
+for PRECISION in fp8 bf16
+do
+
+ # H100 80GB
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 40 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 32 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 24 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 14 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 10 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 3 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 7 --accum 1 --image $IMAGE1 --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 7 --accum 1 --image $IMAGE0 --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE1 --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE0 --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+
+ # INCREASE GPU COUNT
+ for GPU_NUM in 16 32 64
+ do
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g $GPU_NUM --microbatch_size 24 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g $GPU_NUM --microbatch_size 20 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ done
+
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 16 --microbatch_size 10 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 16 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 16 --microbatch_size 10 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 32 --microbatch_size 6 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 32 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 64 --microbatch_size 6 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 64 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 32 --microbatch_size 14 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 32 --microbatch_size 2 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 64 --microbatch_size 16 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 64 --microbatch_size 8 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 11 11 --RUN -t ${PRECISION}
+
+ # SCALE SEQUENCE LENGTH
+ # seqlen 512
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --precision fp8 --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 96 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 56 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 40 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 20 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 12 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 9 9 --RUN -t ${PRECISION}
+ # seqlen 1024
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 48 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 18 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 20 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 40 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 6 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 10 10 --RUN -t ${PRECISION}
+ # seqlen 4096
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 16 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 16 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 12 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 7 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 5 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 16 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 10 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 1 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 12 12 --RUN -t ${PRECISION}
+ # seqlen 8192
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 8 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 8 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 6 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 3 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 3 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 8 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 5 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 13 13 --RUN -t ${PRECISION}
+ # seqlen 16384
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 4 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 4 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 3 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 2 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 1 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN --fsdp_config_activation_checkpointing false -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 4 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 3 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 14 14 --RUN -t ${PRECISION}
+ # seqlen 32768
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 2 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 2 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 3 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 2 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 1 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 15 15 --RUN -t ${PRECISION}
+ # seqlen 65536
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN --fsdp_config_activation_checkpointing true -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN --fsdp_config_activation_checkpointing true -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN -t ${PRECISION}
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type h100_80gb --cluster $CLUSTER_H100 -s 16 16 --RUN -t ${PRECISION}
+done
# A100 80GB
# seqlen 2048
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 32 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 40 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 32 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 24 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 14 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 10 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 32 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 20 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 14 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 10 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 3 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 7 --accum 1 --image $IMAGE1 --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 7 --accum 1 --image $IMAGE0 --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE1 --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 6 --accum 1 --image $IMAGE0 --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+
# INCREASE GPU COUNT
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 16 32 64 --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 16 32 64 --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 16 32 64 --microbatch_size 24 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 16 32 64 --microbatch_size 20 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+for GPU_NUM in 16 32 64
+do
+ python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+ python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g $GPU_NUM --microbatch_size 24 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g $GPU_NUM --microbatch_size 20 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+ python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g $GPU_NUM --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+done
+
python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 16 --microbatch_size 10 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 16 32 64 --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 16 --microbatch_size 24 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 16 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 16 --microbatch_size 10 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 32 64 --microbatch_size 12 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 32 64 --microbatch_size 32 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
+python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 32 --microbatch_size 6 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 32 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 64 --microbatch_size 6 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 64 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 32 --microbatch_size 14 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 32 --microbatch_size 2 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 64 --microbatch_size 16 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 11 11 --RUN
@@ -37,13 +151,13 @@ python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 64 --microb
# SCALE SEQUENCE LENGTH
# seqlen 512
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
+python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --precision fp8 --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 96 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 56 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 40 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 128 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 80 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
+python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN --fsdp_config_activation_checkpointing false
+python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 20 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN --fsdp_config_activation_checkpointing false
python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 12 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 9 9 --RUN
# seqlen 1024
python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 64 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 10 10 --RUN
@@ -71,7 +185,7 @@ python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_si
python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 3 --accum 6 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 13 13 --RUN --fsdp_config_activation_checkpointing false
python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 8 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 13 13 --RUN
python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 5 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 1 --accum 21 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 13 13 --RUN
+python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 8 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 13 13 --RUN
# seqlen 16384
python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 4 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 14 14 --RUN
python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 4 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 14 14 --RUN
@@ -95,80 +209,3 @@ python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_si
python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 16 16 --RUN --fsdp_config_activation_checkpointing true
python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 16 16 --RUN
python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 1 --accum 2 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_80gb --cluster $CLUSTER_80GB -s 16 16 --RUN
-
-
-# A100 40GB
-
-# seqlen 2048
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 26 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 16 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 12 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 8 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 5 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 16 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 4 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-
-# INCREASE GPU COUNT
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 16 32 64 128 --microbatch_size 26 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 16 32 64 128 --microbatch_size 18 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 16 32 64 128 --microbatch_size 12 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 16 --microbatch_size 8 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 16 --microbatch_size 5 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 16 --microbatch_size 16 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 16 --microbatch_size 10 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 32 64 128 --microbatch_size 10 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 32 64 128 --microbatch_size 6 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 32 64 128 --microbatch_size 18 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 32 --microbatch_size 14 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 32 --microbatch_size 4 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 64 128 --microbatch_size 16 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 64 --microbatch_size 2 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 30b.yaml -g 128 --microbatch_size 6 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-python submit_benchmarks.py --project $PROJECT -m 70b.yaml -g 128 --microbatch_size 4 --accum 1 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 11 11 --RUN
-
-# SCALE SEQUENCE LENGTH
-# seqlen 512
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 104 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 64 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 48 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 32 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 20 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 56 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 16 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 9 9 --RUN
-# seqlen 1024
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 52 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 32 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 24 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 16 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 10 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 28 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 8 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 10 10 --RUN
-# seqlen 4096
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 13 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 8 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 6 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 4 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 2 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 8 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 2 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 12 12 --RUN
-# seqlen 8192
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 5 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 4 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 3 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 2 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 1 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN --fsdp_config_activation_checkpointing false
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 3 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-python submit_benchmarks.py --project $PROJECT -m 13b.yaml -g 8 --microbatch_size 1 --accum 16 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 13 13 --RUN
-# seqlen 16384
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 2 --accum 3 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 2 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 1 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 2 --accum 8 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-python submit_benchmarks.py --project $PROJECT -m 7b.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 14 14 --RUN
-# seqlen 32768
-python submit_benchmarks.py --project $PROJECT -m 125m.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 15 15 --RUN
-python submit_benchmarks.py --project $PROJECT -m 350m.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 15 15 --RUN
-python submit_benchmarks.py --project $PROJECT -m 760m.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 15 15 --RUN --fsdp_config_activation_checkpointing true
-python submit_benchmarks.py --project $PROJECT -m 1b.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 15 15 --RUN --fsdp_config_activation_checkpointing true
-python submit_benchmarks.py --project $PROJECT -m 3b.yaml -g 8 --microbatch_size 1 --accum 4 --image $IMAGE --git_commit $GIT_COMMIT --gpu_type a100_40gb --cluster $CLUSTER_40GB -s 15 15 --RUN
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 7358d58d2e..8c1c28eb5c 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -1,9 +1,11 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
+import gc
import logging
import os
import sys
+import time
import warnings
from typing import Any, Dict, List, Optional, Union
@@ -11,6 +13,11 @@
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
+from composer.loggers import MosaicMLLogger
+from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR,
+ MOSAICML_PLATFORM_ENV_VAR)
+from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
+ cyclic_schedule)
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
@@ -210,6 +217,12 @@ def main(cfg: DictConfig) -> Trainer:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
+ # Set CUDA lazy loading
+ # This can save a bit of memory if not all modules are needed
+ cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', True)
+ if cuda_load_lazy:
+ os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+
# Set seed first
seed: int = pop_config(cfg, 'seed', must_exist=True)
reproducibility.seed_all(seed)
@@ -383,10 +396,18 @@ def main(cfg: DictConfig) -> Trainer:
'load_weights_only',
must_exist=False,
default_value=False)
+ load_strict_model_weights: bool = pop_config(cfg,
+ 'load_strict_model_weights',
+ must_exist=False,
+ default_value=True)
load_ignore_keys: Optional[List[str]] = pop_config(cfg,
'load_ignore_keys',
must_exist=False,
default_value=None)
+ compile_config: Optional[Dict[str, Any]] = pop_config(cfg,
+ 'compile_config',
+ must_exist=False,
+ default_value=None)
# Enable autoresume from model checkpoints if possible
autoresume_default: bool = False
if logged_cfg.get('run_name', None) is not None \
@@ -452,7 +473,44 @@ def main(cfg: DictConfig) -> Trainer:
loggers = [
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
- ] if logger_configs else None
+ ] if logger_configs else []
+
+ mosaicml_logger = next(
+ (logger for logger in loggers if isinstance(logger, MosaicMLLogger)),
+ None)
+ if mosaicml_logger is None:
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower(
+ ) == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
+ # Adds mosaicml logger to composer if the run was sent from Mosaic platform, access token is set, and mosaic logger wasn't previously added
+ mosaicml_logger = MosaicMLLogger()
+ loggers.append(mosaicml_logger)
+
+ # Profiling
+ profiler: Optional[Profiler] = None
+ profiler_cfg: Optional[DictConfig] = pop_config(cfg,
+ 'profiler',
+ must_exist=False,
+ convert=False,
+ default_value=None)
+ if profiler_cfg:
+ profiler_schedule_cfg: Dict = pop_config(profiler_cfg,
+ 'schedule',
+ must_exist=True,
+ convert=True)
+ profiler_schedule = cyclic_schedule(**profiler_schedule_cfg)
+ # Only support json trace handler
+ profiler_trace_handlers: List[TraceHandler] = []
+ profiler_trace_cfg: Optional[Dict] = pop_config(profiler_cfg,
+ 'json_trace_handler',
+ must_exist=False,
+ default_value=None,
+ convert=True)
+ if profiler_trace_cfg:
+ profiler_trace_handlers.append(
+ JSONTraceHandler(**profiler_trace_cfg))
+ profiler = Profiler(**profiler_cfg,
+ trace_handlers=profiler_trace_handlers,
+ schedule=profiler_schedule)
# Callbacks
callbacks: List[Callback] = [
@@ -473,6 +531,10 @@ def main(cfg: DictConfig) -> Trainer:
tokenizer,
device_train_batch_size,
)
+
+ if mosaicml_logger is not None:
+ mosaicml_logger.log_metrics({'data_validated': time.time()})
+
## Evaluation
print('Building eval loader...')
evaluators = []
@@ -567,15 +629,19 @@ def main(cfg: DictConfig) -> Trainer:
save_weights_only=save_weights_only,
load_path=load_path,
load_weights_only=load_weights_only,
+ load_strict_model_weights=load_strict_model_weights,
load_ignore_keys=load_ignore_keys,
autoresume=autoresume,
python_log_level=python_log_level,
dist_timeout=dist_timeout,
+ profiler=profiler,
+ compile_config=compile_config,
)
print('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
+ gc.collect()
# Eval first if requested
if eval_first and trainer.state.timestamp.batch.value == 0:
diff --git a/scripts/train/yamls/pretrain/mpt-small-cpu.yaml b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml
new file mode 100644
index 0000000000..cc04f11e44
--- /dev/null
+++ b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml
@@ -0,0 +1,119 @@
+data_local: ./my-copy-c4
+data_remote: # If blank, files must be present in data_local
+max_seq_len: 128
+global_seed: 17
+
+# Run Name
+run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+ name: mpt_causal_lm
+ init_device: cpu
+ d_model: 16
+ n_heads: 4
+ n_layers: 4
+ expansion_ratio: 5
+ max_seq_len: ${max_seq_len}
+ vocab_size: 50368
+ attn_config:
+ attn_impl: torch
+ loss_fn: torch_crossentropy
+
+# Tokenizer
+tokenizer:
+ name: EleutherAI/gpt-neox-20b
+ kwargs:
+ model_max_length: ${max_seq_len}
+
+# Dataloaders
+train_loader:
+ name: text
+ dataset:
+ local: ${data_local}
+ remote: ${data_remote}
+ split: train
+ shuffle: true
+ max_seq_len: ${max_seq_len}
+ shuffle_seed: ${global_seed}
+ drop_last: true
+ num_workers: 2
+
+eval_loader:
+ name: text
+ dataset:
+ local: ${data_local}
+ remote: ${data_remote}
+ split: val
+ shuffle: false
+ max_seq_len: ${max_seq_len}
+ shuffle_seed: ${global_seed}
+ drop_last: false
+ num_workers: 2
+
+# Optimization
+scheduler:
+ name: cosine_with_warmup
+ t_warmup: 100ba
+ alpha_f: 0.1
+
+optimizer:
+ name: decoupled_adamw
+ lr: 6.0e-4
+ betas:
+ - 0.9
+ - 0.95
+ eps: 1.0e-08
+ weight_decay: 0.0
+
+algorithms:
+ gradient_clipping:
+ clipping_type: norm
+ clipping_threshold: 1.0
+
+max_duration: 10ba
+eval_interval: 5ba
+eval_first: false
+eval_subset_num_batches: 5
+global_train_batch_size: 256
+autoresume: false
+
+# System
+seed: ${global_seed}
+device_eval_batch_size: 16
+device_train_microbatch_size: 16
+# device_train_microbatch_size: auto
+precision: fp32
+
+# FSDP
+fsdp_config:
+ sharding_strategy: FULL_SHARD
+ mixed_precision: PURE
+ activation_checkpointing: false
+ activation_checkpointing_reentrant: false
+ activation_cpu_offload: false
+ limit_all_gathers: true
+ verbose: false
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+callbacks:
+ speed_monitor:
+ window_size: 10
+ lr_monitor: {}
+ memory_monitor: {}
+ runtime_estimator: {}
+
+# Checkpoint to local filesystem or remote object store
+save_overwrite: true
+save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
+# save_interval: 500ba
+# save_folder: ./{run_name}/checkpoints
+# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
+
+# Load from local filesystem or remote object store
+# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt
+# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt
diff --git a/setup.py b/setup.py
index e4af7255ea..55e82b9379 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@
install_requires = [
'mosaicml@git+https://github.com/bmosaicml/composer.git@codetracing',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
- 'transformers>=4.33,<4.34',
+ 'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
@@ -114,9 +114,10 @@
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
- if key != 'gpu-flash2')
-extra_deps['all-flash2'] = set(
- dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')
+ if key not in {'gpu-flash2', 'all-cpu'})
+extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
+ for dep in deps
+ if key not in {'gpu', 'all', 'all-cpu'})
setup(
name=_PACKAGE_NAME,
diff --git a/tests/test_builders.py b/tests/test_builders.py
index adff8e55ee..0d24d2154f 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -1,11 +1,15 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
+import unittest.mock as mock
+from typing import Union
+
import pytest
+from composer.callbacks import Generate
from transformers import PreTrainedTokenizerBase
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
-from llmfoundry.utils.builders import build_tokenizer
+from llmfoundry.utils.builders import build_callback, build_tokenizer
@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -29,3 +33,48 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict):
assert tokenizer.model_max_length == tokenizer_kwargs[
'model_max_length']
assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+
+def test_build_callback_fails():
+ with pytest.raises(ValueError):
+ build_callback('nonexistent_callback', {})
+
+
+@pytest.mark.parametrize(
+ 'interval_key,interval_value',
+ [('interval', '10ba'), ('batch_log_interval', 10)],
+)
+def test_build_generate_callback(
+ interval_key: str,
+ interval_value: Union[str, int],
+):
+
+ with mock.patch.object(Generate, '__init__',
+ autospec=True) as mock_generate:
+ mock_generate.return_value = None
+ build_callback(
+ 'generate_callback', {
+ 'prompts': ['hello'],
+ interval_key: interval_value,
+ 'foo': 'bar',
+ 'something': 'else',
+ })
+
+ assert mock_generate.call_count == 1
+ _, _, kwargs = mock_generate.mock_calls[0]
+ assert kwargs['prompts'] == ['hello']
+ assert kwargs['interval'] == '10ba'
+ assert kwargs['something'] == 'else'
+ assert kwargs['foo'] == 'bar'
+
+
+def test_build_generate_callback_unspecified_interval():
+ with pytest.raises(KeyError):
+ with mock.patch.object(Generate, '__init__',
+ autospec=True) as mock_generate:
+ mock_generate.return_value = None
+ build_callback('generate_callback', {
+ 'prompts': ['hello'],
+ 'foo': 'bar',
+ 'something': 'else',
+ })
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 6495eccf65..656b6d52a6 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -3,22 +3,27 @@
import contextlib
import os
import pathlib
+import random
import shutil
import sys
import tempfile
from argparse import Namespace
from typing import Optional
+from unittest.mock import MagicMock
import pytest
import torch
+import transformers
from composer.utils import dist, using_torch_2
+from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import MDSWriter
from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
- build_text_dataloader)
+ build_text_dataloader,
+ get_tokens_per_batch_func)
from llmfoundry.utils.builders import build_tokenizer
# Add repo root to path so we can import scripts and test it
@@ -137,7 +142,7 @@ def test_correct_padding(tokenizer_name: str,
test_cfg.eval_loader,
tokenizer,
batch_size,
- )
+ ).dataloader
batch = next(iter(eval_loader))
assert batch['input_ids'].shape == torch.Size([batch_size, 2048])
@@ -228,7 +233,7 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
tokenizer_kwargs={'model_max_length': max_seq_len})
loader = build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size)
+ device_batch_size).dataloader
batch_ix = 0
for batch in loader:
for k in expected_keys:
@@ -287,7 +292,8 @@ def test_finetuning_dataloader(decoder_only_format: bool,
else:
expected_keys += ['decoder_attention_mask', 'decoder_input_ids']
- loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+ loader = build_finetuning_dataloader(cfg, tokenizer,
+ device_batch_size).dataloader
batch_ix = 0
for batch in loader:
for k in expected_keys:
@@ -541,7 +547,8 @@ def test_malformed_data(
match='Unable to tokenize example')
with error_context:
- dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
+ dl = build_finetuning_dataloader(cfg, tokenizer,
+ device_batch_size).dataloader
if not add_bad_data_error:
# +5 because we added samples with just bos/eos in each of prompt/response
@@ -552,3 +559,175 @@ def test_malformed_data(
actual_num_batches += 1
assert actual_num_batches == expected_num_batches
+
+
+@pytest.mark.parametrize('pad_token_id', [0, 100, 1000])
+@pytest.mark.parametrize('batch_size', [1, 8, 16])
+@pytest.mark.parametrize('model_max_length', [1024, 2048])
+@pytest.mark.parametrize('padding_side', ['left', 'right'])
+@pytest.mark.parametrize('add_decoder_input_ids', [True, False])
+def test_token_counting_func(pad_token_id: int, batch_size: int,
+ model_max_length: int, padding_side: str,
+ add_decoder_input_ids: bool):
+ gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
+ gptt.pad_token_id = pad_token_id
+ gptt.model_max_length = model_max_length
+ gptt.padding_side = padding_side
+
+ batch_strings = []
+ expected_token_count = 0
+ for _ in range(batch_size):
+ sample_length = random.randint(1, model_max_length)
+ batch_strings.append(' '.join(['hello'] * sample_length))
+ expected_token_count += sample_length
+
+ batch_tokenized = gptt(batch_strings, padding=True, return_tensors='pt')
+
+ if add_decoder_input_ids:
+ decoder_batch_strings = []
+ decoder_expected_token_count = 0
+ for _ in range(batch_size):
+ sample_length = random.randint(1, model_max_length)
+ decoder_batch_strings.append(' '.join(['hello'] * sample_length))
+ decoder_expected_token_count += sample_length
+ expected_token_count += sample_length
+ batch_tokenized['decoder_input_ids'] = gptt(
+ decoder_batch_strings, padding=True,
+ return_tensors='pt')['input_ids']
+
+ token_counting_func = get_tokens_per_batch_func(
+ pad_token_id, decoder_only=not add_decoder_input_ids)
+
+ actual_token_count = token_counting_func(batch_tokenized)
+
+ assert actual_token_count == expected_token_count
+
+
+@pytest.mark.parametrize(
+ 'dataloader_type',
+ ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text'])
+@pytest.mark.parametrize('pad_token_id', [100, None])
+@pytest.mark.parametrize('batch_size', [1, 8])
+@pytest.mark.parametrize('model_max_length', [1024])
+@pytest.mark.parametrize('padding_side', ['left'])
+def test_token_counting_func_dataloader_setting(
+ dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
+ model_max_length: int, padding_side: str,
+ monkeypatch: pytest.MonkeyPatch):
+ gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
+ gptt.pad_token_id = pad_token_id
+ gptt.model_max_length = model_max_length
+ gptt.padding_side = padding_side
+
+ batch_strings = []
+ expected_token_count = 0
+ for _ in range(batch_size):
+ sample_length = random.randint(
+ 1,
+ model_max_length) if pad_token_id is not None else model_max_length
+ batch_strings.append(' '.join(['hello'] * sample_length))
+ expected_token_count += sample_length
+
+ batch_tokenized = gptt(batch_strings,
+ padding=True if pad_token_id is not None else False,
+ return_tensors='pt')
+
+ if dataloader_type == 'denoising':
+ batch_tokenized['decoder_input_ids'] = batch_tokenized[
+ 'input_ids'].clone()
+ expected_token_count *= 2
+
+ common_args = {
+ 'drop_last': False,
+ 'num_workers': 0,
+ 'prefetch_factor': None if using_torch_2() else 2,
+ 'pin_memory': False,
+ 'persistent_workers': False,
+ 'timeout': 0
+ }
+
+ if dataloader_type == 'finetuning-hf':
+ cfg = DictConfig({
+ 'name': 'finetuning',
+ 'dataset': {
+ 'hf_name': 'dummy-path',
+ 'split': 'train',
+ 'max_seq_len': model_max_length,
+ 'decoder_only_format': True,
+ 'allow_pad_trimming': False,
+ 'packing_ratio': None,
+ 'shuffle': True,
+ },
+ **common_args
+ })
+ monkeypatch.setattr(
+ 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf',
+ lambda *args, **kwargs: [])
+ dl = build_finetuning_dataloader(cfg, gptt, batch_size)
+ elif dataloader_type == 'finetuning-streaming':
+ cfg = DictConfig({
+ 'name': 'finetuning',
+ 'dataset': {
+ 'remote': 'dummy-path',
+ 'local': 'dummy-path',
+ 'split': 'train',
+ 'max_seq_len': model_max_length,
+ 'decoder_only_format': True,
+ 'allow_pad_trimming': False,
+ 'packing_ratio': None,
+ 'shuffle': True,
+ },
+ **common_args
+ })
+ monkeypatch.setattr(
+ 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming',
+ lambda *args, **kwargs: [])
+ dl = build_finetuning_dataloader(cfg, gptt, batch_size)
+ elif dataloader_type == 'text':
+ cfg = DictConfig({
+ 'name': 'text',
+ 'dataset': {
+ 'local': 'dummy-path',
+ 'remote': 'dummy-path',
+ 'split': 'train',
+ 'max_seq_len': model_max_length,
+ 'shuffle': True,
+ 'shuffle_seed': 0,
+ },
+ **common_args
+ })
+ monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
+ lambda *args, **kwargs: MagicMock())
+ dl = build_text_dataloader(cfg, gptt, batch_size)
+ elif dataloader_type == 'denoising':
+ cfg = DictConfig({
+ 'name': 'text_denoising',
+ 'dataset': {
+ 'local': 'dummy-path',
+ 'remote': 'dummy-path',
+ 'split': 'val_xsmall',
+ 'shuffle': False,
+ 'max_seq_len': model_max_length,
+ 'packing_ratio': None,
+ 'predownload': 1000,
+ 'keep_zip': False,
+ 'num_workers': None
+ },
+ 'mixture_of_denoisers': {
+ 'decoder_only_format': False,
+ 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
+ 'sequence_mask_ratios': 0.25,
+ },
+ **common_args
+ })
+ monkeypatch.setattr('llmfoundry.data.denoising.StreamingTextDataset',
+ lambda *args, **kwargs: MagicMock())
+ dl = build_text_denoising_dataloader(cfg, gptt, batch_size)
+ else:
+ raise NotImplementedError()
+
+ cfg = om.create(cfg)
+
+ actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized)
+
+ assert actual_token_count == expected_token_count
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index c944dcfc97..fcb2cc3a7e 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -5,8 +5,10 @@
import os
import pathlib
import sys
+from unittest.mock import MagicMock
from composer import Trainer
+from composer.loggers import MLFlowLogger
from composer.utils import dist, get_device
from llmfoundry.callbacks import HuggingFaceCheckpointer
@@ -17,7 +19,7 @@
sys.path.append(repo_dir)
import shutil
from argparse import Namespace
-from typing import cast
+from typing import Optional, cast
import pytest
import torch
@@ -136,6 +138,49 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)
+ # Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
+ # constructor differences
+ additional_special_tokens_1 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
+ }
+ additional_special_tokens_2 = {
+ t if isinstance(t, str) else t.content
+ for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
+ }
+ # Also pop it out of init_kwargs
+ tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+ tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+ # If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
+ assert additional_special_tokens_1.issubset(
+ additional_special_tokens_2) or additional_special_tokens_2.issubset(
+ additional_special_tokens_1)
+
+ # The special token attributes may be strings or they may be AddedToken objects, so we just check string values
+ # First check that they have the same attrs
+ assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
+ # Then check that the values are the same
+ for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
+ # Skip additional_special_tokens because we already checked it above
+ if special_token_attr == 'additional_special_tokens':
+ continue
+
+ # The init_kwargs can change between the original tokenizer and the loaded tokenizer,
+ # so we just pop them
+ tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
+ tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)
+
+ attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
+ attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
+ if attr1 is None and attr2 is None:
+ continue
+
+ attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
+ attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
+ assert attr_value1 == attr_value2
+
assert tokenizer1.__dict__ == tokenizer2.__dict__
@@ -148,6 +193,23 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
# so we remove it
expected_model_config_dict.pop('_name_or_path')
new_model_config_dict.pop('_name_or_path')
+
+ # Special case a couple of differences that correctly occur when saving MPT to huggingface format
+ # checkpoint
+ architectures_1 = expected_model_config_dict.pop('architectures', None)
+ architectures_2 = new_model_config_dict.pop('architectures', None)
+ if architectures_1 != architectures_2:
+ assert architectures_1 is None and architectures_2 == ['MPTForCausalLM']
+
+ auto_map_1 = expected_model_config_dict.pop('auto_map', None)
+ auto_map_2 = new_model_config_dict.pop('auto_map', None)
+ if auto_map_1 != auto_map_2:
+ assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'}
+ assert auto_map_2 == {
+ 'AutoConfig': 'configuration_mpt.MPTConfig',
+ 'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM'
+ }
+
assert expected_model_config_dict == new_model_config_dict
assert all(
torch.equal(p1.cpu(), p2.cpu())
@@ -155,6 +217,10 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
def delete_transformers_cache():
+ # Only delete the files on local rank 0, otherwise race conditions are created
+ if not dist.get_local_rank() == 0:
+ return
+
hf_cache_home = os.path.expanduser(
os.getenv(
'HF_HOME',
@@ -183,9 +249,11 @@ def test_callback_inits_with_defaults():
@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
-@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded'])
+@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
+@pytest.mark.parametrize('log_to_mlflow', [True, False])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
- fsdp_state_dict_type: str):
+ fsdp_state_dict_type: Optional[str],
+ log_to_mlflow: bool):
delete_transformers_cache()
dist.initialize_dist(get_device('gpu'))
@@ -203,6 +271,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{huggingface_save_interval_batches}ba',
precision=precision_str,
+ mlflow_registered_model_name='dummy-registered-name'
+ if log_to_mlflow else None,
)
# get small version of each model
@@ -324,20 +394,35 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
optimizer = build_optimizer(original_model, optimizer_name,
optimizer_config)
+ mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
+ mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
+ mlflow_logger_mock.save_model = MagicMock()
+ mlflow_logger_mock.register_model = MagicMock()
+ mlflow_logger_mock.model_registry_prefix = ''
trainer = Trainer(
model=original_model,
device='gpu',
- fsdp_config=fsdp_config,
+ fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{save_interval_batches}ba',
max_duration=f'{max_duration_batches}ba',
callbacks=[checkpointer_callback],
+ loggers=[mlflow_logger_mock] if log_to_mlflow else [],
optimizers=optimizer,
save_latest_filename=None,
)
trainer.fit()
+ if dist.get_global_rank() == 0:
+ assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
+ else 0)
+ assert mlflow_logger_mock.register_model.call_count == (
+ 1 if log_to_mlflow else 0)
+ else:
+ assert mlflow_logger_mock.log_model.call_count == 0
+ assert mlflow_logger_mock.register_model.call_count == 0
+
# summon full params to check equivalence
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with FSDP.summon_full_params(trainer.state.model,
@@ -390,10 +475,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trust_remote_code=True,
)
- check_hf_model_equivalence(trainer.state.model.model.to(precision),
- loaded_model)
+ check_hf_model_equivalence(
+ trainer.state.model.model.to(precision) if fsdp_state_dict_type
+ is not None else trainer.state.model.module.model.to(precision),
+ loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
+ dist.barrier()
delete_transformers_cache()
diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py
index 68cef14c43..cc357141ba 100644
--- a/tests/test_hf_mpt_gen.py
+++ b/tests/test_hf_mpt_gen.py
@@ -1,16 +1,22 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
from typing import Any, Dict
+from unittest.mock import Mock
import pytest
+from composer.callbacks import Generate as ComposerGenerate
from composer.core.precision import get_precision_context
+from composer.trainer import Trainer
from composer.utils import get_device, reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.data.finetuning import build_finetuning_dataloader
from llmfoundry.utils import build_tokenizer
+from tests.data_utils import make_tiny_ft_dataset
@pytest.mark.gpu
@@ -72,3 +78,90 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):
def test_init_hfhub_mpt_cpu():
test_init_hfhub_mpt(device='cpu', attn_impl='torch')
+
+
+@pytest.mark.gpu
+def test_mpt_generate_callback(tmpdir: Path):
+ composer_device = get_device('gpu')
+ reproducibility.seed_all(42)
+ max_seq_len = 128
+
+ # testing dataset and dataloader
+ dataset_size = 5
+
+ tiny_dataset_path = tmpdir / 'test-ift-data-small'
+ tiny_dataset_path.mkdir()
+ tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
+ make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
+
+ dataloader_cfg = DictConfig({
+ 'name': 'finetuning',
+ 'dataset': {
+ 'hf_name': str(tiny_dataset_path),
+ 'split': 'train',
+ 'max_seq_len': max_seq_len,
+ 'decoder_only_format': True,
+ 'allow_pad_trimming': False,
+ 'packing_ratio': None,
+ 'shuffle': True,
+ },
+ 'drop_last': False,
+ 'num_workers': 4,
+ 'pin_memory': False,
+ 'prefetch_factor': 2,
+ 'persistent_workers': False,
+ 'timeout': 0
+ })
+
+ # build tokenizer
+ tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
+
+ # build mpt model
+ model_config = DictConfig({
+ 'name': 'mpt_causal_lm',
+ 'config_overrides': {
+ 'd_model': 128,
+ 'n_heads': 4,
+ 'n_layers': 2,
+ 'expansion_ratio': 2,
+ },
+ })
+ model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
+ model = composer_device.module_to_device(model)
+
+ # generate callback
+ prompts = [
+ 'The best banana bread recipe is',
+ '2+2=',
+ 'how much wood could a woodchuck chuck',
+ ]
+ gen_interval = 1
+ generate = ComposerGenerate(
+ prompts,
+ interval=f'{gen_interval}ba',
+ max_new_tokens=5,
+ batch_size=len(prompts),
+ use_cache=True,
+ )
+ generate.generate = Mock(wraps=generate.generate, autospec=True)
+
+ # build trainer
+ device_batch_size = 1
+ train_dataloader = build_finetuning_dataloader(
+ dataloader_cfg,
+ tokenizer,
+ device_batch_size,
+ )
+
+ trainer = Trainer(
+ model=model,
+ train_dataloader=train_dataloader,
+ device=composer_device,
+ max_duration=f'{gen_interval}ba',
+ callbacks=[generate],
+ )
+ trainer.logger.log_table = Mock()
+ trainer.fit()
+
+ generate.generate.assert_called_once()
+ trainer.logger.log_table.assert_called_once()
diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py
new file mode 100644
index 0000000000..a71217ea1f
--- /dev/null
+++ b/tests/test_huggingface_flash.py
@@ -0,0 +1,195 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import contextlib
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+from composer.core.precision import get_precision_context
+from composer.utils import reproducibility
+from omegaconf import OmegaConf as om
+
+from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.hf.hf_fsdp import rgetattr
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+ is_flash_v2_installed)
+from llmfoundry.utils.builders import build_tokenizer
+
+# Before importing any transformers models, we need to disable transformers flash attention if
+# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+# gated import otherwise.
+if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
+
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+from llmfoundry.models.layers.llama_attention_monkeypatch import (
+ llama_attention_patch_torch, llama_attention_patch_triton)
+
+
+@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
+@pytest.mark.parametrize('explicit_mask', [True, False])
+@pytest.mark.parametrize(
+ 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
+@pytest.mark.gpu
+def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
+ model_name: str):
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+
+ device = 'cuda:0'
+ sequence_length = 4096
+ model_dim = 4096 if '7b' in model_name else 8192
+ batch_size = 2
+ if patch_fn_name == 'torch':
+ patch_fn = llama_attention_patch_torch
+ dtype = torch.float32
+ atol = 0.0
+ rtol = 0.0
+ elif patch_fn_name == 'triton':
+ # the huggingface implementation of llama performs the softmax in fp32
+ # this can result in fairly large differences for the triton implementation
+ # but the torch implementation produces the exact same output so we can confirm
+ # the implementation is correct
+ patch_fn = llama_attention_patch_triton
+ dtype = torch.bfloat16
+ atol = 1e-2
+ rtol = 1e-2
+ else:
+ raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
+
+ llama_config = transformers.AutoConfig.from_pretrained(model_name,
+ use_auth_token=True)
+
+ reproducibility.seed_all(42)
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+
+ rng = torch.Generator(device=device).manual_seed(42)
+ hidden_states = torch.randn(batch_size,
+ sequence_length,
+ model_dim,
+ generator=rng,
+ dtype=dtype,
+ device=device)
+ causal_mask = torch.full((sequence_length, sequence_length),
+ torch.finfo(torch.float32).min,
+ device=device)
+ causal_mask = causal_mask.triu(diagonal=1)
+ causal_mask = causal_mask[None,
+ None, :, :].expand(batch_size, 1, sequence_length,
+ sequence_length)
+ attn_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ reproducibility.seed_all(42)
+ with patch.object(LlamaAttention, 'forward', new=patch_fn):
+ attention = LlamaAttention(config=llama_config,)
+ attention.to(dtype=dtype, device=device)
+ new_output, _, _ = attention(
+ hidden_states=hidden_states,
+ attention_mask=causal_mask if explicit_mask else None,
+ position_ids=None,
+ past_key_value=None,
+ use_cache=False,
+ )
+
+ assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
+@pytest.mark.parametrize('use_flash_attention_2', [True, False])
+def test_flash2(model_name: str, use_flash_attention_2: bool):
+ if model_name == 'llama2':
+ if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+ pytest.skip(
+ 'The CI cluster does not have access to the Llama models, so skip this test.'
+ )
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'use_auth_token': True,
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'meta-llama/Llama-2-7b-hf'
+ from transformers.models.llama.modeling_llama import (
+ LlamaAttention, LlamaFlashAttention2)
+ flash_attn_class = LlamaFlashAttention2 if use_flash_attention_2 else LlamaAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ elif model_name == 'mistral':
+ model_cfg = {
+ 'name': 'hf_causal_lm',
+ 'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
+ 'config_overrides': {
+ 'num_hidden_layers': 2,
+ 'intermediate_size': 64,
+ },
+ 'pretrained': False,
+ 'init_device': 'cpu',
+ }
+
+ tokenizer_name = 'mistralai/Mistral-7B-v0.1'
+ from transformers.models.mistral.modeling_mistral import (
+ MistralAttention, MistralFlashAttention2)
+ flash_attn_class = MistralFlashAttention2 if use_flash_attention_2 else MistralAttention
+ attention_layers_attr = 'model.model.layers'
+ attention_attr = 'self_attn'
+ else:
+ raise ValueError(f'Unknown model: {model_name}')
+
+ if use_flash_attention_2:
+ model_cfg['use_flash_attention_2'] = True
+
+ model_cfg = om.create(model_cfg)
+
+ tokenizer = build_tokenizer(
+ tokenizer_name=tokenizer_name,
+ tokenizer_kwargs={'model_max_length': 10},
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+
+ error_context = pytest.raises(
+ ValueError, match='use_flash_attention_2 is set to True'
+ ) if not is_flash_v2_installed(
+ ) and use_flash_attention_2 else contextlib.nullcontext()
+
+ with error_context:
+ model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
+
+ # check that it actually used flash attention 2
+ assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled
+ attention_layer = rgetattr(
+ rgetattr(model, attention_layers_attr)[0], attention_attr)
+ assert isinstance(attention_layer, flash_attn_class)
+
+ tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
+ return_tensors='pt',
+ padding=True)
+ tokenized_input['labels'] = tokenized_input['input_ids'].clone()
+
+ tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
+ model.to('cuda')
+
+ with get_precision_context('amp_bf16'):
+ # We're just testing that flash attention 2 runs okay
+ outputs = model(tokenized_input)
+ loss = outputs.loss
+ loss.backward()
diff --git a/tests/test_llama_patch.py b/tests/test_llama_patch.py
deleted file mode 100644
index b1cd3711e0..0000000000
--- a/tests/test_llama_patch.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2022 MosaicML LLM Foundry authors
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-
-import pytest
-import torch
-import transformers
-from composer.utils import reproducibility
-from transformers.models.llama.modeling_llama import LlamaAttention
-
-from llmfoundry.models.layers.llama_attention_monkeypatch import (
- llama_attention_patch_torch, llama_attention_patch_triton)
-
-
-@pytest.mark.parametrize('patch_fn_name', ['torch', 'triton'])
-@pytest.mark.parametrize('explicit_mask', [True, False])
-@pytest.mark.parametrize(
- 'model_name', ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-70b-hf'])
-@pytest.mark.gpu
-def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
- model_name: str):
- if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
- pytest.skip(
- 'The CI cluster does not have access to the Llama models, so skip this test.'
- )
-
- original_forward = LlamaAttention.forward
-
- device = 'cuda:0'
- sequence_length = 4096
- model_dim = 4096 if '7b' in model_name else 8192
- batch_size = 2
- if patch_fn_name == 'torch':
- patch_fn = llama_attention_patch_torch
- dtype = torch.float32
- atol = 0.0
- rtol = 0.0
- elif patch_fn_name == 'triton':
- # the huggingface implementation of llama performs the softmax in fp32
- # this can result in fairly large differences for the triton implementation
- # but the torch implementation produces the exact same output so we can confirm
- # the implementation is correct
- patch_fn = llama_attention_patch_triton
- dtype = torch.bfloat16
- atol = 1e-2
- rtol = 1e-2
- else:
- raise ValueError(f'Unknown patch_fn_name: {patch_fn_name}')
-
- llama_config = transformers.AutoConfig.from_pretrained(model_name,
- use_auth_token=True)
-
- reproducibility.seed_all(42)
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
-
- rng = torch.Generator(device=device).manual_seed(42)
- hidden_states = torch.randn(batch_size,
- sequence_length,
- model_dim,
- generator=rng,
- dtype=dtype,
- device=device)
- causal_mask = torch.full((sequence_length, sequence_length),
- torch.finfo(torch.float32).min,
- device=device)
- causal_mask = causal_mask.triu(diagonal=1)
- causal_mask = causal_mask[None,
- None, :, :].expand(batch_size, 1, sequence_length,
- sequence_length)
- attn_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- reproducibility.seed_all(42)
- LlamaAttention.forward = patch_fn
- attention = LlamaAttention(config=llama_config,)
- attention.to(dtype=dtype, device=device)
- new_output, _, _ = attention(
- hidden_states=hidden_states,
- attention_mask=causal_mask if explicit_mask else None,
- position_ids=None,
- past_key_value=None,
- use_cache=False,
- )
-
- # Reset the forward function so patches don't persist
- LlamaAttention.forward = original_forward
-
- assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py
new file mode 100644
index 0000000000..06ddccd479
--- /dev/null
+++ b/tests/test_mpt_gen.py
@@ -0,0 +1,98 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Optional, Tuple
+from unittest.mock import patch
+
+import pytest
+import torch
+from composer.core.precision import get_precision_context
+from composer.utils import dist, get_device, reproducibility
+from omegaconf import DictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM
+from llmfoundry.utils import build_tokenizer
+
+EOS_TOKEN_ID = 0
+
+
+class MockMPTForCausalLM(MPTForCausalLM):
+ """Class that overrides the forward of MPTForCausalLM."""
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.ByteTensor] = None,
+ prefix_mask: Optional[torch.ByteTensor] = None,
+ sequence_id: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ result = super().forward(input_ids, past_key_values, attention_mask,
+ prefix_mask, sequence_id, labels, return_dict,
+ output_attentions, output_hidden_states,
+ use_cache, inputs_embeds)
+ # Modify the logits to select the next token.
+ if dist.get_global_rank() == 0:
+ # Rank 0 hits EOS immediately.
+ result.logits[:, :, EOS_TOKEN_ID] = torch.inf
+ else:
+ # Other ranks do not hit EOS.
+ result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
+ return result
+
+
+@pytest.mark.world_size(2)
+@pytest.mark.gpu
+@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('use_alibi', [True, False])
+@patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
+ new=MockMPTForCausalLM)
+def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool):
+ """Tests mpt generation with mutiple gpus.
+
+ and generations of different lengths.
+ """
+ composer_device = get_device('gpu')
+ dist.initialize_dist(composer_device)
+ reproducibility.seed_all(42)
+
+ model_config = DictConfig({
+ 'name': 'mpt_causal_lm',
+ 'd_model': 128,
+ 'n_heads': 4,
+ 'n_layers': 2,
+ 'expansion_ratio': 2,
+ 'no_bias': False,
+ 'use_cache': True,
+ 'attn_config': {
+ 'attn_impl': attn_impl,
+ 'attn_uses_sequence_id': False,
+ 'alibi': use_alibi
+ },
+ })
+
+ # build tokenizer
+ tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
+
+ # build model
+ model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
+ model = composer_device.module_to_device(model)
+ model.eval()
+
+ model.model = FSDP(model.model)
+
+ with get_precision_context('amp_bf16'):
+ _ = model.generate(composer_device.tensor_to_device(
+ tokenizer('hello', return_tensors='pt')['input_ids']),
+ max_new_tokens=3,
+ eos_token_id=EOS_TOKEN_ID,
+ use_cache=True,
+ synced_gpus=True)
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
new file mode 100644
index 0000000000..5b9d45a141
--- /dev/null
+++ b/tests/test_scheduler.py
@@ -0,0 +1,113 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import pytest
+import torch
+from composer.core import State, Time, TimeUnit
+from composer.devices import DeviceCPU, DeviceGPU
+from composer.optim.scheduler import ComposerScheduler
+
+from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
+
+_MAX_DURATION = '100ba'
+_STEPS_PER_EPOCH = 100
+
+
+@pytest.fixture
+def dummy_schedulers_state(request: pytest.FixtureRequest):
+ device = None
+ for item in request.session.items:
+ device = DeviceCPU(
+ ) if item.get_closest_marker('gpu') is None else DeviceGPU()
+ break
+ assert device != None
+ state = State(
+ model=torch.nn.Linear(5, 5),
+ run_name='run_name',
+ device=device,
+ rank_zero_seed=17,
+ max_duration=_MAX_DURATION,
+ )
+ state.set_dataloader([None] * _STEPS_PER_EPOCH, 'train')
+ return state
+
+
+@pytest.mark.parametrize('scheduler,ssr,test_times,expected_lrs', [
+ pytest.param(
+ InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
+ t_scale='10ba',
+ t_cooldown='0ba',
+ alpha_f_decay=0,
+ alpha_f_cooldown=0), 1.0,
+ ['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'],
+ [0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623]),
+ pytest.param(
+ InverseSquareRootWithWarmupScheduler(t_warmup='20ba',
+ t_scale='2ba',
+ t_cooldown='10ba',
+ alpha_f_decay=0.4,
+ alpha_f_cooldown=0.1), 1.0,
+ ['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'],
+ [0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1]),
+])
+def test_scheduler_init(scheduler: ComposerScheduler, ssr: float,
+ test_times: List[str], expected_lrs: List[float],
+ dummy_schedulers_state: State):
+
+ state = dummy_schedulers_state
+ assert state.dataloader_len is not None
+ assert state.max_duration is not None
+ state.max_duration = Time(value=int(state.max_duration.value * ssr),
+ unit=state.max_duration.unit)
+ for test_time, expected_lr in zip(test_times, expected_lrs):
+ parsed_time = Time.from_timestring(test_time)
+ assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH]
+ state.timestamp = state.timestamp.copy(
+ batch=parsed_time,
+ epoch=Time(
+ int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH),
+ )
+ lr = scheduler(state, ssr)
+ assert lr == pytest.approx(expected_lr, abs=1e-3)
+
+
+@pytest.mark.parametrize('state_unit,warmup_unit,scale_unit,cooldown_unit', [
+ ['ep', 'ba', 'ba', 'ba'],
+ ['ba', 'ep', 'ep', 'ep'],
+ ['ep', 'ep', 'ba', 'ep'],
+])
+def test_scheduler_units_match_error(state_unit: str, warmup_unit: str,
+ scale_unit: str, cooldown_unit: str,
+ dummy_schedulers_state: State):
+
+ state = dummy_schedulers_state
+ state.max_duration = f'1{state_unit}'
+ scheduler = InverseSquareRootWithWarmupScheduler(
+ t_warmup=f'10{warmup_unit}',
+ t_scale=f'10{scale_unit}',
+ t_cooldown=f'10{cooldown_unit}')
+ with pytest.raises(ValueError, match='does not match'):
+ _ = scheduler(state, 1.0)
+
+
+@pytest.mark.parametrize('warmup_unit,scale_unit,cooldown_unit', [
+ ['dur', 'ba', 'ba'],
+ ['ba', 'dur', 'ba'],
+ ['ba', 'ba', 'dur'],
+])
+def test_unit_dur_error(warmup_unit: str, scale_unit: str, cooldown_unit: str):
+ with pytest.raises(ValueError, match='cannot be in units of "dur".'):
+ _ = InverseSquareRootWithWarmupScheduler(t_warmup=f'1{warmup_unit}',
+ t_scale=f'1{scale_unit}',
+ t_cooldown=f'1{cooldown_unit}')
+
+
+def test_alpha_f_error():
+ with pytest.raises(ValueError, match='alpha_f_decay >= alpha_f_cooldown.'):
+ _ = InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
+ t_scale='10ba',
+ t_cooldown='10ba',
+ alpha_f_decay=0.0,
+ alpha_f_cooldown=0.1)
diff --git a/tests/test_tiktoken.py b/tests/test_tiktoken.py
index a255a5ffa7..85ff18100b 100644
--- a/tests/test_tiktoken.py
+++ b/tests/test_tiktoken.py
@@ -45,14 +45,19 @@
def get_tokenizers_for_testing(
- model_name: Optional[str], encoding_name: Optional[str],
- tmp_path: pathlib.Path
+ model_name: Optional[str],
+ encoding_name: Optional[str],
+ tmp_path: pathlib.Path,
+ add_bos_token: bool = False,
+ add_eos_token: bool = False
) -> Tuple[TiktokenTokenizerWrapper, TiktokenTokenizerWrapper, 'Encoding']:
tiktoken = pytest.importorskip('tiktoken')
# Construction
wrapped_tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
- encoding_name=encoding_name)
+ encoding_name=encoding_name,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token)
if model_name is not None:
original_tokenizer = tiktoken.encoding_for_model(model_name)
else:
@@ -201,3 +206,29 @@ def test_tiktoken_save_from_pretrained(model_name: Optional[str],
model_name, encoding_name, tmp_path)
check_hf_tokenizer_equivalence(wrapped_tokenizer,
reloaded_wrapped_tokenizer)
+
+
+@pytest.mark.parametrize('model_name,encoding_name',
+ MODEL_ENCODING_NAME_PARAMETRIZATION)
+def test_tiktoken_encode_plus(model_name: Optional[str],
+ encoding_name: Optional[str],
+ tmp_path: pathlib.Path):
+ # Testing encode_plus which optionally wrap encodes with bos and eos tokens
+ wrapped_tokenizer, _, _ = get_tokenizers_for_testing(model_name,
+ encoding_name,
+ tmp_path,
+ add_bos_token=True,
+ add_eos_token=True)
+
+ for test_string in TEST_STRINGS:
+ encoded_outputs = wrapped_tokenizer.encode_plus(
+ test_string,
+ add_special_tokens=True,
+ return_special_tokens_mask=True)
+ encoded_input_ids = encoded_outputs.input_ids
+ assert encoded_input_ids[0] == wrapped_tokenizer.bos_token_id
+ assert encoded_input_ids[-1] == wrapped_tokenizer.eos_token_id
+
+ encoded_special_mask = encoded_outputs.special_tokens_mask
+ assert encoded_special_mask[0] == 1
+ assert encoded_special_mask[-1] == 1