diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index bb538dbe9b..bbc5e422d3 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -7,7 +7,7 @@ on:
     branches:
     - main
     paths:
-    - ./Dockerfile
+    - Dockerfile
     - .github/workflows/docker.yaml
   workflow_dispatch: {}
 jobs:
diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py
index 496abf2aa6..9bb2ece2b2 100644
--- a/llmfoundry/models/inference_api_wrapper/__init__.py
+++ b/llmfoundry/models/inference_api_wrapper/__init__.py
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.models.inference_api_wrapper.fmapi import (
+    FMAPICausalLMEvalWrapper, FMAPIChatAPIEvalWrapper)
 from llmfoundry.models.inference_api_wrapper.interface import \
     InferenceAPIEvalWrapper
 from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
@@ -10,4 +12,6 @@
     'OpenAICausalLMEvalWrapper',
     'OpenAIChatAPIEvalWrapper',
     'InferenceAPIEvalWrapper',
+    'FMAPICausalLMEvalWrapper',
+    'FMAPIChatAPIEvalWrapper',
 ]
diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py
new file mode 100644
index 0000000000..867b3c272e
--- /dev/null
+++ b/llmfoundry/models/inference_api_wrapper/fmapi.py
@@ -0,0 +1,72 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import time
+from typing import Dict
+
+import requests
+from transformers import AutoTokenizer
+
+from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
+    OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface)
+
+__all__ = [
+    'FMAPICausalLMEvalWrapper',
+    'FMAPIChatAPIEvalWrapper',
+]
+
+log = logging.getLogger(__name__)
+
+
+def block_until_ready(base_url: str):
+    """Block until the endpoint is ready."""
+    sleep_s = 5
+    timeout_s = 5 * 60  # At max, wait 5 minutes
+
+    ping_url = f'{base_url}/ping'
+
+    waited_s = 0
+    while True:
+        try:
+            requests.get(ping_url)
+            log.info(f'Endpoint {ping_url} is ready')
+            break
+        except requests.exceptions.ConnectionError:
+            log.debug(
+                f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds'
+            )
+            time.sleep(sleep_s)
+            waited_s += sleep_s
+
+            if waited_s >= timeout_s:
+                raise TimeoutError(
+                    f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
+                )
+
+
+class FMAPIEvalInterface(OpenAIEvalInterface):
+
+    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
+        is_local = model_cfg.pop('local', False)
+        if is_local:
+            base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
+                                      'http://0.0.0.0:8080/v2')
+            model_cfg['base_url'] = base_url
+            block_until_ready(base_url)
+
+        if 'base_url' not in model_cfg:
+            raise ValueError(
+                'Must specify base_url or use local=True in model_cfg for FMAPIEvalWrapper'
+            )
+
+        super().__init__(model_cfg, tokenizer)
+
+
+class FMAPICausalLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper):
+    """Databricks Foundational Model API wrapper for causal LM models."""
+
+
+class FMAPIChatAPIEvalWrapper(FMAPIEvalInterface, OpenAIChatAPIEvalWrapper):
+    """Databricks Foundational Model API wrapper for chat models."""
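
Note: a minimal usage sketch of the new wrapper, for review context. The `model_cfg` keys (`local`, `name`, `base_url`) come from this diff; the tokenizer and model names are illustrative assumptions, and the `openai` extra must be installed.

```python
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper import FMAPICausalLMEvalWrapper

tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')  # illustrative

# local=True fills in base_url from MOSAICML_MODEL_ENDPOINT (default
# http://0.0.0.0:8080/v2) and polls {base_url}/ping via block_until_ready
# until the endpoint answers or the 5-minute timeout expires.
model_cfg = {'local': True, 'name': 'mpt-7b-instruct'}  # hypothetical name
wrapper = FMAPICausalLMEvalWrapper(model_cfg=model_cfg, tokenizer=tokenizer)
```
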
diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
index 39de2ba59c..587dd179bd 100644
--- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
+++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -36,9 +36,6 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper):
 
     def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
         super().__init__(model_cfg, tokenizer)
-        assert os.getenv(
-            'OPENAI_API_KEY'
-        ) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
         try:
             import openai
         except ImportError as e:
@@ -46,8 +43,28 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
                 extra_deps_group='openai',
                 conda_package='openai',
                 conda_channel='conda-forge') from e
-        self.client = openai.OpenAI()
-        self.model_name = model_cfg['version']
+
+        api_key = os.environ.get('OPENAI_API_KEY')
+        base_url = model_cfg.get('base_url')
+        if base_url is None:
+            # Using OpenAI default, where the API key is required
+            if api_key is None:
+                raise ValueError(
+                    'No OpenAI API Key found. Ensure it is saved as an environment variable called OPENAI_API_KEY.'
+                )
+
+        else:
+            # Using a custom base URL, where the API key may not be required
+            log.info(
+                f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}'
+            )
+            api_key = 'placeholder'  # This cannot be None
+
+        self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
+        if 'version' in model_cfg:
+            self.model_name = model_cfg['version']
+        else:
+            self.model_name = model_cfg['name']
 
     def generate_completion(self, prompt: str, num_tokens: int):
         raise NotImplementedError()
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index fecd79553f..89f861c3f0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -516,6 +516,7 @@ def __init__(
         attn_impl: str = 'triton',
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
+        qk_gn: bool = False,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -529,6 +530,7 @@ def __init__(
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv
         self.qk_ln = qk_ln
+        self.qk_gn = qk_gn
 
         self.d_model = d_model
         self.n_heads = n_heads
@@ -549,6 +551,8 @@ def __init__(
             raise ValueError(
                 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.'
             )
+        if qk_ln and qk_gn:
+            raise ValueError('Only one of qk_ln and qk_gn can be set to True.')
 
         self.softmax_scale = softmax_scale
         if self.softmax_scale is None:
@@ -572,11 +576,13 @@ def __init__(
         ]
         self.Wqkv._fused = (0, fuse_splits)
 
-        if self.qk_ln:
+        if self.qk_ln or self.qk_gn:
             norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-            self.q_ln = norm_class(self.d_model, device=device)
-            self.k_ln = norm_class(self.kv_n_heads * self.head_dim,
-                                   device=device)
+            norm_size = self.head_dim if qk_gn else d_model
+            self.q_ln = norm_class(norm_size, device=device)
+            if qk_ln:
+                norm_size = self.head_dim * kv_n_heads
+            self.k_ln = norm_class(norm_size, device=device)
 
         if self.attn_impl == 'flash':
             self.attn_fn = flash_attn_fn
@@ -623,11 +629,16 @@ def forward(
 
         key_padding_mask = attention_mask
 
-        if self.qk_ln:
+        if self.qk_ln or self.qk_gn:
             # Applying layernorm to qk
+            q_shape, k_shape = query.shape, key.shape
+            if self.qk_gn:
+                b, s = query.shape[:2]
+                query = query.view(b, s, self.n_heads, -1)
+                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype)
-            key = self.k_ln(key).to(dtype)
+            query = self.q_ln(query).to(dtype).view(q_shape)
+            key = self.k_ln(key).to(dtype).view(k_shape)
 
         if rotary_emb_w_meta_info is not None:
             rotary_emb = rotary_emb_w_meta_info['rotary_emb']
@@ -712,6 +723,7 @@ def __init__(
         attn_impl: str = 'triton',
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
+        qk_gn: bool = False,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -727,6 +739,7 @@ def __init__(
             attn_impl=attn_impl,
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
+            qk_gn=qk_gn,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
@@ -751,6 +764,7 @@ def __init__(
         attn_impl: str = 'triton',
         clip_qkv: Optional[float] = None,
         qk_ln: bool = False,
+        qk_gn: bool = False,
         softmax_scale: Optional[float] = None,
         attn_pdrop: float = 0.0,
         norm_type: str = 'low_precision_layernorm',
@@ -766,6 +780,7 @@ def __init__(
             attn_impl=attn_impl,
             clip_qkv=clip_qkv,
             qk_ln=qk_ln,
+            qk_gn=qk_gn,
             softmax_scale=softmax_scale,
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
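
Note: a self-contained sketch of the normalization shapes introduced above, useful for sanity-checking the `view` round-trip in `forward`. Dimensions are made up, and `nn.LayerNorm` stands in for whatever `norm_type` selects from `NORM_CLASS_REGISTRY`.

```python
import torch
from torch import nn

# Toy dimensions (illustrative only).
b, s, d_model, n_heads, kv_n_heads = 2, 4, 16, 4, 2
head_dim = d_model // n_heads

query = torch.randn(b, s, d_model)

# qk_ln: one norm over the full query projection (size d_model), and a
# separate norm of size kv_n_heads * head_dim for the keys.
q_ln = nn.LayerNorm(d_model)
normed_query = q_ln(query)

# qk_gn: reshape to (b, s, heads, head_dim), normalize each head on its own
# (norm size head_dim), then restore the original shape, as in forward().
q_gn = nn.LayerNorm(head_dim)
q_shape = query.shape
grouped = q_gn(query.view(b, s, n_heads, -1)).view(q_shape)
assert grouped.shape == query.shape
```
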
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 036a4e7cd2..4ac43a8bac 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -22,6 +22,7 @@
     'attn_pdrop': 0.0,
     'attn_impl': 'triton',
     'qk_ln': False,
+    'qk_gn': False,
     'clip_qkv': None,
     'softmax_scale': None,
     'prefix_lm': False,
diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py
index be09a69835..ff9942f5f6 100644
--- a/llmfoundry/models/model_registry.py
+++ b/llmfoundry/models/model_registry.py
@@ -3,7 +3,9 @@
 
 from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                   ComposerHFT5)
-from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper,
+from llmfoundry.models.inference_api_wrapper import (FMAPICausalLMEvalWrapper,
+                                                     FMAPIChatAPIEvalWrapper,
+                                                     OpenAICausalLMEvalWrapper,
                                                      OpenAIChatAPIEvalWrapper)
 from llmfoundry.models.mpt import ComposerMPTCausalLM
 
@@ -13,5 +15,7 @@
     'hf_prefix_lm': ComposerHFPrefixLM,
     'hf_t5': ComposerHFT5,
     'openai_causal_lm': OpenAICausalLMEvalWrapper,
-    'openai_chat': OpenAIChatAPIEvalWrapper
+    'fmapi_causal_lm': FMAPICausalLMEvalWrapper,
+    'openai_chat': OpenAIChatAPIEvalWrapper,
+    'fmapi_chat': FMAPIChatAPIEvalWrapper,
 }
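
Note: with these entries, an eval config can select the new wrappers by name. A rough sketch of the lookup, assuming the dict shown in this hunk is llm-foundry's `COMPOSER_MODEL_REGISTRY` (the variable name is outside the hunk context):

```python
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY

# 'fmapi_causal_lm' and 'fmapi_chat' are the keys added in this hunk.
wrapper_cls = COMPOSER_MODEL_REGISTRY['fmapi_causal_lm']
print(wrapper_cls.__name__)  # FMAPICausalLMEvalWrapper
```
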
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 5474529277..7911728397 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -82,6 +82,7 @@ def __init__(
             attn_pdrop (float): The dropout probability for the attention layers.
             attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
             qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
+            qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
             clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value.
             softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                 use the default scale of ``1/sqrt(d_keys)``.
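
Note: since `qk_gn` now lives in `attn_config`, a config exercising it might look like the following. A minimal sketch with illustrative sizes; unlisted `attn_config` keys fall back to the defaults shown in blocks.py above.

```python
from llmfoundry.models.mpt import MPTConfig

# qk_ln and qk_gn are mutually exclusive; setting both raises a ValueError
# in the attention module per the attention.py change above.
config = MPTConfig(
    d_model=256,  # illustrative sizes
    n_heads=8,
    n_layers=4,
    attn_config={
        'attn_impl': 'torch',
        'qk_gn': True,  # group-norm each head's queries and keys
    },
)
```
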
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 75438b895e..29642381f8 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -219,21 +219,16 @@ def build_callback(
 
 
 def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
-    kwargs_dict = {
-        k: v if isinstance(v, str) else om.to_container(v, resolve=True)
-        for k, v in kwargs.items()
-    }
-
     if name == 'wandb':
-        return WandBLogger(**kwargs_dict)
+        return WandBLogger(**kwargs)
     elif name == 'tensorboard':
-        return TensorboardLogger(**kwargs_dict)
+        return TensorboardLogger(**kwargs)
     elif name == 'in_memory_logger':
-        return InMemoryLogger(**kwargs_dict)
+        return InMemoryLogger(**kwargs)
     elif name == 'mlflow':
-        return MLFlowLogger(**kwargs_dict)
+        return MLFlowLogger(**kwargs)
     elif name == 'inmemory':
-        return InMemoryLogger(**kwargs_dict)
+        return InMemoryLogger(**kwargs)
     else:
         raise ValueError(f'Not sure how to build logger: {name}')
diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py
index 5d8a413d91..07c84a85c8 100644
--- a/llmfoundry/utils/model_download_utils.py
+++ b/llmfoundry/utils/model_download_utils.py
@@ -30,6 +30,12 @@
 ]
 PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
 SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
+TOKENIZER_FILES = [
+    'special_tokens_map.json',
+    'tokenizer.json',
+    'tokenizer.model',
+    'tokenizer_config.json',
+]
 
 ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
 ORAS_CLI = 'oras'
@@ -45,6 +51,7 @@ def download_from_hf_hub(
     model: str,
     save_dir: str,
     prefer_safetensors: bool = True,
+    tokenizer_only: bool = False,
     token: Optional[str] = None,
 ):
     """Downloads model files from a Hugging Face Hub model repo.
@@ -57,6 +64,7 @@ def download_from_hf_hub(
         save_dir (str, optional): The local path to the directory where the model files will be downloaded.
         prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available.
             Defaults to True.
+        tokenizer_only (bool): If true, only download tokenizer files.
         token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
             `HUGGING_FACE_HUB_TOKEN` environment variable.
 
@@ -95,10 +103,13 @@ def download_from_hf_hub(
             ' Please make sure the repo contains either safetensors or pytorch weights.'
         )
 
+    allow_patterns = TOKENIZER_FILES if tokenizer_only else None
+
     download_start = time.time()
     hf_hub.snapshot_download(model,
                              local_dir=save_dir,
                              ignore_patterns=ignore_patterns,
+                             allow_patterns=allow_patterns,
                              token=token)
     download_duration = time.time() - download_start
     log.info(
@@ -221,16 +232,18 @@ def download_from_oras(model: str,
                        config_file: str,
                        credentials_dir: str,
                        save_dir: str,
+                       tokenizer_only: bool = False,
                        concurrency: int = 10):
     """Download from an OCI-compliant registry using oras.
 
     Args:
-        model: The name of the model to download.
-        config_file: Path to a YAML config file that maps model names to registry paths.
-        credentials_dir: Path to a directory containing credentials for the registry. It is expected to contain three
+        model (str): The name of the model to download.
+        config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
+        credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
             files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
-        save_dir: Path to the directory where files will be downloaded.
-        concurrency: The number of concurrent downloads to run.
+        save_dir (str): Path to the directory where files will be downloaded.
+        tokenizer_only (bool): If true, only download the tokenizer files.
+        concurrency (int): The number of concurrent downloads to run.
     """
     if shutil.which(ORAS_CLI) is None:
         raise Exception(
@@ -253,7 +266,8 @@ def _read_secrets_file(secret_file_path: str,):
 
     with open(config_file, 'r', encoding='utf-8') as f:
         configs = yaml.safe_load(f.read())
-    path = configs['models'][model]
+    config_type = 'tokenizers' if tokenizer_only else 'models'
+    path = configs[config_type][model]
     registry = secrets['registry']
 
     def get_oras_cmd(username: Optional[str] = None,
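
Note: a usage sketch for the new `tokenizer_only` path; the repo id and save path are illustrative, and network access to the Hugging Face Hub is assumed.

```python
from llmfoundry.utils.model_download_utils import download_from_hf_hub

# With tokenizer_only=True, only the TOKENIZER_FILES defined above
# (special_tokens_map.json, tokenizer.json, tokenizer.model,
# tokenizer_config.json) are passed as allow_patterns to snapshot_download.
download_from_hf_hub(
    'mosaicml/mpt-7b',  # illustrative repo id
    save_dir='/tmp/mpt-7b-tokenizer',
    tokenizer_only=True,
)
```
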
diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
index 2218e575b2..bfd60b8ee1 100644
--- a/scripts/data_prep/convert_text_to_mds.py
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -385,6 +385,10 @@ def convert_text_to_mds(
     local_output_folder = tempfile.TemporaryDirectory(
     ).name if is_remote_output else output_folder
 
+    if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0:
+        raise FileExistsError(
+            f'{output_folder=} is not empty. Please remove or empty it.')
+
     if processes > 1:
         # Download and convert the text files in parallel
         args = get_task_args(object_names, local_output_folder, input_folder,
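
Note: the new guard makes a run against a non-empty output folder fail fast instead of silently mixing shards. A behavior sketch with made-up local paths, assuming the guard fires before any conversion work starts:

```python
import os

from scripts.data_prep.convert_text_to_mds import convert_text_to_mds

os.makedirs('/tmp/text-in', exist_ok=True)
with open('/tmp/text-in/doc.txt', 'w') as f:
    f.write('hello world')

os.makedirs('/tmp/mds-out', exist_ok=True)
open('/tmp/mds-out/stale.file', 'w').close()  # output folder is not empty

try:
    convert_text_to_mds(
        tokenizer_name='mosaicml/mpt-7b',
        output_folder='/tmp/mds-out',
        input_folder='/tmp/text-in',
        concat_tokens=2048,
        eos_text='',
        bos_text='',
        no_wrap=False,
        compression='zstd',
        processes=1,
        args_str='Namespace()',
        reprocess=True,
    )
except FileExistsError as e:
    print(e)  # output_folder='/tmp/mds-out' is not empty. Please remove or empty it.
```
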
diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py
index 1913267e20..13a63ce55e 100644
--- a/scripts/misc/download_model.py
+++ b/scripts/misc/download_model.py
@@ -7,10 +7,11 @@
     python download_model.py hf --model mosaicml/mpt-7b --save-dir <save_dir> --token <token>
 
 Download from ORAS registry:
-    python download_model.py oras --registry <registry> --path mosaicml/mpt-7b --save-dir <save_dir>
+    python download_model.py oras --model mosaicml/mpt-7b --config-file <config_file> \
+        --credentials-dir <credentials_dir> --save-dir <save_dir>
 
 Download from an HTTP file server:
-    python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir>
+    python download_model.py http --url https://server.com/models/mosaicml/mpt-7b/ --save-dir <save_dir>
 
 Download from an HTTP file server with fallback to Hugging Face Hub:
     python download_model.py http --host https://server.com --path mosaicml/mpt-7b --save-dir <save_dir> \
@@ -56,6 +57,9 @@ def parse_args() -> argparse.Namespace:
 
     base_parser = argparse.ArgumentParser(add_help=False)
     base_parser.add_argument('--save-dir', type=str, required=True)
+    base_parser.add_argument('--tokenizer-only',
+                             default=False,
+                             action='store_true')
 
     # Add subparser for downloading from Hugging Face Hub.
     hf_parser = subparsers.add_parser('hf', parents=[base_parser])
@@ -85,6 +89,9 @@ def parse_args() -> argparse.Namespace:
     download_from = args.download_from
 
     if download_from == 'http':
+        if args.tokenizer_only:
+            raise ValueError(
+                'tokenizer-only is not currently supported for http.')
         try:
             download_from_http_fileserver(args.url, args.save_dir,
                                           args.ignore_cert)
@@ -109,7 +116,12 @@ def parse_args() -> argparse.Namespace:
         download_from_hf_hub(args.model,
                              save_dir=args.save_dir,
                              token=args.token,
+                             tokenizer_only=args.tokenizer_only,
                              prefer_safetensors=args.prefer_safetensors)
     elif download_from == 'oras':
-        download_from_oras(args.model, args.config_file, args.credentials_dir,
-                           args.save_dir, args.concurrency)
+        download_from_oras(args.model,
+                           args.config_file,
+                           args.credentials_dir,
+                           args.save_dir,
+                           tokenizer_only=args.tokenizer_only,
+                           concurrency=args.concurrency)
diff --git a/scripts/train/train.py b/scripts/train/train.py
index ba703ad60b..32ddd54fda 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -242,7 +242,8 @@ def main(cfg: DictConfig) -> Trainer:
     logger_configs: Optional[DictConfig] = pop_config(cfg,
                                                       'loggers',
                                                       must_exist=False,
-                                                      default_value=None)
+                                                      default_value=None,
+                                                      convert=True)
     callback_configs: Optional[DictConfig] = pop_config(cfg,
                                                         'callbacks',
                                                         must_exist=False,
diff --git a/setup.py b/setup.py
index 1307e3a024..215b72a868 100644
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,7 @@
 install_requires = [
     'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.19',  # TEMPORARY
     'accelerate>=0.25,<0.26',  # for HF inference `device_map`
-    'transformers>=4.36,<4.37',
+    'transformers>=4.37,<4.38',
     'mosaicml-streaming>=0.7.2,<0.8',
     # 'torch>=2.1,<2.4',  # TEMPORARY
     'datasets>=2.16,<2.17',
diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
index cc293a2cdd..3a00a8889f 100644
--- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py
+++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
@@ -3,6 +3,7 @@
 
 import os
 import pathlib
+import shutil
 from concurrent.futures import ProcessPoolExecutor
 from glob import glob
 from typing import Callable, Iterable, List
@@ -55,23 +56,6 @@ def upload_object(self, object_name: str, filename: str):
             remote_file.write(local_file.read())
 
 
-def _call_convert_text_to_mds(processes: int, tokenizer_name: str,
-                              concat_tokens: int) -> None:
-    convert_text_to_mds(
-        tokenizer_name=tokenizer_name,
-        output_folder=f's3://fake-test-output-path',
-        input_folder=f's3://fake-test-input-path',
-        concat_tokens=concat_tokens,
-        eos_text='',
-        bos_text='',
-        no_wrap=False,
-        compression='zstd',
-        processes=processes,
-        args_str='Namespace()',
-        reprocess=False,
-    )
-
-
 # Mock starmap with no multiprocessing
 def _mock_map(func: Callable, args: Iterable) -> Iterable:
     for arg in args:
@@ -107,9 +91,22 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
     maybe_create_object_store_from_uri.return_value = mock_object_store
     parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder))
 
-    _call_convert_text_to_mds(processes=processes,
-                              tokenizer_name=tokenizer_name,
-                              concat_tokens=concat_tokens)
+    def call_convert_text_to_mds() -> None:
+        convert_text_to_mds(
+            tokenizer_name=tokenizer_name,
+            output_folder=f's3://fake-test-output-path',
+            input_folder=f's3://fake-test-input-path',
+            concat_tokens=concat_tokens,
+            eos_text='',
+            bos_text='',
+            no_wrap=False,
+            compression='zstd',
+            processes=processes,
+            args_str='Namespace()',
+            reprocess=False,
+        )
+
+    call_convert_text_to_mds()
 
     # Check call counts
     assert download_and_convert.call_count == processes  # called once per process
@@ -131,9 +128,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
     _assert_files_exist(prefix=remote_folder,
                         files=['index.json', DONE_FILENAME] + shards)
 
-    _call_convert_text_to_mds(processes=processes,
-                              tokenizer_name=tokenizer_name,
-                              concat_tokens=concat_tokens)
+    call_convert_text_to_mds()
 
     # Check call counts
     assert download_and_convert.call_count == processes  # No changes because we shouldn't reprocess
@@ -146,9 +141,7 @@ def test_single_and_multi_process(merge_shard_groups: Mock,
     mock_object_store = Mock(wraps=object_store)
     maybe_create_object_store_from_uri.return_value = mock_object_store
 
-    _call_convert_text_to_mds(processes=processes,
-                              tokenizer_name=tokenizer_name,
-                              concat_tokens=concat_tokens)
+    call_convert_text_to_mds()
 
     # Check call counts
     assert download_and_convert.call_count == processes * 2  # called once per process
@@ -187,31 +180,42 @@ def test_local_path(tmp_path: pathlib.Path):
     input_folder = tmp_path / 'input'
     output_folder = tmp_path / 'output'
 
+    def call_convert_text_to_mds(reprocess: bool):
+        convert_text_to_mds(
+            tokenizer_name='mosaicml/mpt-7b',
+            output_folder=str(output_folder),
+            input_folder=str(input_folder),
+            concat_tokens=1,
+            eos_text='',
+            bos_text='',
+            no_wrap=False,
+            compression='zstd',
+            processes=1,
+            args_str='Namespace()',
+            reprocess=reprocess,
+        )
+
     # Create input text data
     os.makedirs(input_folder, exist_ok=True)
     with open(input_folder / 'test.txt', 'w') as f:
         f.write('test')
 
     # Convert text data to mds
-    convert_text_to_mds(
-        tokenizer_name='mosaicml/mpt-7b',
-        output_folder=str(output_folder),
-        input_folder=str(input_folder),
-        concat_tokens=1,
-        eos_text='',
-        bos_text='',
-        no_wrap=False,
-        compression='zstd',
-        processes=1,
-        args_str='Namespace()',
-        reprocess=False,
-    )
+    call_convert_text_to_mds(reprocess=False)
 
     # Make sure all the files exist as expected.
     assert os.path.exists(output_folder / '.text_to_mds_conversion_done')
     assert os.path.exists(output_folder / 'index.json')
     assert os.path.exists(output_folder / 'shard.00000.mds.zstd')
 
+    # Test reprocessing.
+    with pytest.raises(FileExistsError):
+        call_convert_text_to_mds(reprocess=True)
+
+    shutil.rmtree(output_folder)
+
+    call_convert_text_to_mds(reprocess=True)
+
 
 def test_is_already_processed(tmp_path: pathlib.Path):
     tmp_path_str = str(tmp_path)
""" alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] @@ -100,6 +105,7 @@ def test_attn_impl(attn_impl_0: str, 'attn_pdrop': 0, 'clip_qkv': clip_qkv, 'qk_ln': qk_ln, + 'qk_gn': qk_gn, }) n, s, f = 2, 4, cfg.d_model @@ -269,7 +275,8 @@ def gen_bias(attn_impl: str): 'rope_impl'] == 'hf' # special case that (likely) fails due to numerics - if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention': + if (clip_qkv and (qk_ln or qk_gn) and using_hf_rope and + attn_type == 'grouped_query_attention'): assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2) else: assert allclose_helper(p.grad, tp.grad) diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 9be6630075..303afc9b7d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -135,14 +135,14 @@ def test_build_logger(): with pytest.raises(ValueError): _ = build_logger('unknown', {}) - logger_cfg = DictConfig({ + logger_cfg = { 'project': 'foobar', 'init_kwargs': { 'config': { 'foo': 'bar', } } - }) + } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) assert wandb_logger.project == 'foobar' diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 471a39dcdb..14749bdcd9 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -110,6 +110,7 @@ def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, mock_snapshot_download.assert_called_once_with( test_repo_id, local_dir=save_dir, + allow_patterns=None, ignore_patterns=expected_ignore_patterns, token=None)