Commit 5d9d824: merged main

snarayan21 committed Sep 27, 2023
2 parents 7ad4c8d + ca729a5
Showing 19 changed files with 689 additions and 34 deletions.
11 changes: 8 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -293,11 +293,14 @@ def _build_hf_dataset_from_remote(
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
destination_split if destination_split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
@@ -306,10 +309,12 @@ def _build_hf_dataset_from_remote(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{cfg.dataset.split}-00000-of-00001.{extension}')))
f'{destination_split}-00000-of-00001.{extension}')))

# Since we don't know exactly what the extension will be (it is one of a list
# of supported extensions), use a signal file to wait on instead of the
# desired file
signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
signal_file_path = os.path.join(
finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
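
The renamed signal file keys on the node rank so that multi-node runs sharing a filesystem don't collide. As a rough sketch of the pattern this hunk relies on (not the verbatim implementation; `download()` is a hypothetical stand-in for composer's `get_file`):

```python
# Hedged sketch: per-node "local rank 0 downloads, everyone else waits" barrier.
import os
import time

from composer.utils import dist


def download_on_local_rank_zero(remote_path: str, destination: str) -> None:
    # One signal file per node, so each node's local rank 0 downloads once.
    signal_file = f'.node_{dist.get_node_rank()}_local_rank0_completed'
    if dist.get_local_rank() == 0:
        download(remote_path, destination)  # hypothetical stand-in for get_file
        with open(signal_file, 'wb') as f:
            f.write(b'local_rank0_completed_download')
    else:
        # Wait on the signal file rather than the artifact itself, since the
        # artifact's final extension may not be known in advance.
        while not os.path.exists(signal_file):
            time.sleep(1)
    dist.barrier()  # sync all ranks before cleaning up
    if dist.get_local_rank() == 0:
        os.remove(signal_file)
```
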
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -319,7 +319,9 @@ def build_from_hf(
Dataset: The tokenized dataset.
"""
dataset_name = cfg.hf_name
split = cfg.split
# HF datasets does not support a split with dashes, so we replace
# dashes with underscores.
split = cfg.split.replace('-', '_')
kwargs = cfg.get('hf_kwargs', {})
proto_preprocessing_fn = cfg.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, dict) or isinstance(
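
Both this hunk and the dataloader change normalize the split name the same way, e.g.:

```python
# HF datasets rejects dashes in split names, so both call sites do:
split = 'eval-small'.replace('-', '_')
assert split == 'eval_small'
```
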
6 changes: 4 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -9,7 +9,8 @@

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningLMAccuracy,
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
@@ -74,6 +75,7 @@ def __init__(self, om_model_config: Union[DictConfig,
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
@@ -164,7 +166,7 @@ def __init__(self, om_model_config: Union[DictConfig,
f'init_device="{init_device}" must be either "cpu" or "meta".'
)

signal_file_path = '.local_rank0_completed_autoresume'
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
60 changes: 47 additions & 13 deletions llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@

import math
import warnings
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
return original_is_causal


def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Perform repeat of kv heads along a particular dimension.
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
n_rep: number of repetitions of kv_n_heads
Unlike torch.repeat_interleave, this function avoids allocating new memory.
"""
if n_rep == 1:
return hidden

b, s, kv_n_heads, d = hidden.shape

hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)

return hidden.reshape(b, s, kv_n_heads * n_rep, d)
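
As an aside, the expand-and-reshape above produces exactly the same layout as `repeat_interleave` along the head dimension; a small self-contained check (mine, not part of the commit):

```python
# Sanity check: expand + reshape matches torch.repeat_interleave along dim=2.
import torch

b, s, kv_n_heads, d, n_rep = 2, 5, 4, 8, 3
hidden = torch.randn(b, s, kv_n_heads, d)

expanded = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
out = expanded.reshape(b, s, kv_n_heads * n_rep, d)

assert torch.equal(out, hidden.repeat_interleave(n_rep, dim=2))
```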


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention(

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
# transpose k and v to swap (b, h, s, d) -> (b, s, h, d) for repeat_kv_for_gqa, then transpose back
k = repeat_kv_for_gqa(k.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)
v = repeat_kv_for_gqa(v.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
@@ -243,10 +263,16 @@ def flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head group-size times so we can use the underlying MHA kernels
# done along the head dimension = 1
key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
dim=1)

# Since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d),
# we use .view to reshape {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

dropout_p = dropout_p if training else 0.0

@@ -383,9 +409,8 @@ def triton_flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head group-size times so we can use the underlying MHA kernels
# done along dim = 2, unlike the implementation for flash and torch attn
key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
@@ -419,6 +444,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()

@@ -450,7 +476,9 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
@@ -557,6 +585,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
@@ -569,7 +598,9 @@
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


class MultiQueryAttention(GroupedQueryAttention):
@@ -591,6 +622,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
@@ -603,7 +635,9 @@
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


def attn_bias_shape(
15 changes: 10 additions & 5 deletions llmfoundry/models/layers/blocks.py
@@ -26,6 +26,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
**kwargs: Any,
):
if attn_config is None:
@@ -66,11 +67,14 @@ def __init__(
}

self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
@@ -79,6 +83,7 @@
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
**ffn_config,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
8 changes: 7 additions & 1 deletion llmfoundry/models/layers/ffn.py
@@ -24,9 +24,12 @@ def __init__(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()
fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.up_proj = FC_CLASS_REGISTRY[fc_type](
@@ -60,6 +63,7 @@ def build_ffn(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
@@ -72,12 +76,14 @@
expansion_ratio=expansion_ratio,
fc_type=fc_type,
device=device,
bias=bias,
)
elif ffn_type == 'te_ln_mlp':
assert te is not None
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=d_model * expansion_ratio,
bias=bias,
**kwargs,
)

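Going by the signature visible in this hunk, a call might look like the following sketch; the `'mptmlp'` registry key is an assumption based on the surrounding file:

```python
# Hedged usage sketch for build_ffn: ffn_type travels through **kwargs and is
# popped inside, and bias is now threaded to both the MPTMLP and te_ln_mlp paths.
from llmfoundry.models.layers.ffn import build_ffn

ffn = build_ffn(
    d_model=512,
    expansion_ratio=4,
    fc_type='torch',
    device='cpu',
    bias=False,         # new in this change
    ffn_type='mptmlp',  # assumed registry key for the MPTMLP branch
)
```
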
9 changes: 8 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -14,7 +14,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics import (InContextLearningLMAccuracy,
from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
@@ -150,6 +151,11 @@ def __init__(self, config: MPTConfig):
log.info(f'Removing bias ({module.bias}) from {module}.')
module.register_parameter('bias', None)

# For transformer engine
if hasattr(module, 'use_bias'):
log.info(f'Setting use_bias=False for {module}.')
module.use_bias = False

log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')
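
The loop context is elided above, but the shape of the no-bias pass is roughly the following sketch (`model` and `log` are assumed from the surrounding code; the loop header is my reconstruction):

```python
# Rough reconstruction of the no_bias pass: plain torch modules get their bias
# Parameter removed; transformer-engine modules expose a use_bias flag instead.
import torch

for module in model.modules():
    if hasattr(module, 'bias') and isinstance(module.bias, torch.nn.Parameter):
        log.info(f'Removing bias ({module.bias}) from {module}.')
        module.register_parameter('bias', None)

    # For transformer engine
    if hasattr(module, 'use_bias'):
        log.info(f'Setting use_bias=False for {module}.')
        module.use_bias = False
```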

@@ -695,6 +701,7 @@ def __init__(
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
]
8 changes: 8 additions & 0 deletions llmfoundry/utils/builders.py
@@ -229,6 +229,8 @@ def _validate_cfg(icl_cfg: DictConfig):
]
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(
f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
@@ -244,6 +246,10 @@ def _validate_cfg(icl_cfg: DictConfig):
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
icl_cfg.pass_at_k = 1
if 'num_beams' not in icl_cfg:
icl_cfg.num_beams = 20

for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
@@ -274,6 +280,8 @@ def _validate_cfg(icl_cfg: DictConfig):
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
destination_path=destination_path,
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
if hasattr(
35 changes: 30 additions & 5 deletions scripts/eval/README.md
@@ -153,12 +153,13 @@ This document explains the ICL formats compatible with [Composer](https://github

## Supported ICL formats

Composer currently supports four ICL formats
Composer currently supports five ICL formats:

1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L92-L253)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L256-L402)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L405-L599)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L602-L773)
1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L103)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L293)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L444)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L676)
5. [InContextLearningCodeEvalDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L852)

----

@@ -346,6 +347,30 @@ Below is a YAML section that works with the Winograd dataset in [`scripts/eval/l
continuation_delimiter: ' ' # this separates questions from answers

----

### InContextLearningCodeEvalDataset

The ICL CodeEvalDataset takes a prompt and, together with the NLP metric [InContextLearningCodeEvalAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningCodeEvalAccuracy.html), generates code that is run against the supplied tests, as in HumanEval ([Evaluating Large Language Models Trained on Code](https://arxiv.org/abs/2107.03374)) and MBPP ([Program Synthesis with Large Language Models](https://arxiv.org/abs/2108.07732)). Generation involves many decoding steps, so each sample can take longer than in other ICL tasks. An example datum:

```json
{"task_id": "JavaScript/2", "prompt": "/* Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncateNumber(3.5)\n 0.5\n */\nconst truncateNumber = (number) => {\n", "canonical_solution": " return number % 1.0;\n}\n\n", "test": "const testTruncateNumber = () => {\n console.assert(truncateNumber(3.5) === 0.5)\n\n console.assert(Math.abs(truncateNumber(1.33) - 0.33) < 1e-6)\n\n console.assert(Math.abs(truncateNumber(123.456 - 0.456) < 1e-6))\n}\n\ntestTruncateNumber()\n", "entry_point": "truncateNumber", "test_inputs": ["3.5", "1.33", "123.456"], "test_outputs": ["0.5", "0.33", "0.456"], "language": "javascript"}
```

Required keys for each datum:

* `prompt: str`
* `test: str`
* `entry_point: str`
* `test_inputs: List[str]`
* `test_outputs: List[str]`
* `language: str`

Code evaluation can happen locally (insecure) or inside an AWS Lambda function sandbox. This is controlled by setting the environment variable `CODE_EVAL_DEVICE` to `LOCAL` or `LAMBDA`. If set to `LAMBDA`, you must also provide `CODE_EVAL_URL` and `CODE_EVAL_APIKEY` to query the API gateway in the AWS Sandbox.
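
No code-eval YAML section appears in this diff, but by analogy with the sections above (and the `pass_at_k` / `num_beams` defaults added in `llmfoundry/utils/builders.py`), a plausible section might look like this; the label and dataset URI below are made up for illustration:

```yaml
  label: human_eval                              # hypothetical label
  dataset_uri: eval/local_data/human_eval.jsonl  # hypothetical path
  num_fewshot: [0]
  icl_task_type: code_evaluation
  pass_at_k: 1    # defaults to 1 when omitted
  num_beams: 20   # defaults to 20; passed as generations_per_sample
```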

----

### Build your own dataset (BYOD)
Building a dataset compatible with our eval suite is straightforward if it fits one of the supported task types. Simply choose the appropriate task type (LM, MC, QA, Schema, or code evaluation) and process each dataset into a jsonl format in which each row has the format described above.
