diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 87a08a999d..4d72f19cbc 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -55,6 +55,17 @@ import_file, save_registry, ) +from llmfoundry.utils.validation_notebook_utils import ( + check_HF_datasets, + convert_text_to_mds, + create_om_cfg, + integrity_check, + is_hf_dataset_path, + is_uc_delta_table, + parse_args, + plot_hist, + token_counts_with_collate, +) from llmfoundry.utils.warnings import ( ExperimentalWarning, VersionedDeprecationWarning, @@ -111,4 +122,13 @@ 'ExperimentalWarning', 'experimental_function', 'experimental_class', + 'check_HF_datasets', + 'convert_text_to_mds', + 'create_om_cfg', + 'integrity_check', + 'is_hf_dataset_path', + 'is_uc_delta_table', + 'parse_args', + 'token_counts_with_collate', + 'plot_hist', ] diff --git a/llmfoundry/utils/validation_notebook_utils.py b/llmfoundry/utils/validation_notebook_utils.py new file mode 100644 index 0000000000..b7e3fd25a5 --- /dev/null +++ b/llmfoundry/utils/validation_notebook_utils.py @@ -0,0 +1,531 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import os +import re +import tempfile +from argparse import Namespace +from concurrent.futures import ProcessPoolExecutor +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from composer.utils import ( + maybe_create_object_store_from_uri, +) +from datasets import get_dataset_split_names +from huggingface_hub import dataset_info +from omegaconf import OmegaConf as om +from streaming.base.storage.download import download_file +from streaming.base.storage.upload import CloudUploader +from transformers import AutoTokenizer + +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( + ConcatTokensFromFilesDataset, + get_object_names, + get_task_args, + is_already_processed, + is_remote_path, +) +from llmfoundry.utils import build_tokenizer +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + merge_shard_groups, +) +from llmfoundry.utils.exceptions import ( + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) + +log = logging.getLogger(__name__) + + +def create_om_cfg(FT_API_args: Namespace): + task_type = FT_API_args.task_type + + train_data_path = FT_API_args.train_data_path + split = 'train' + + if is_hf_dataset_path(FT_API_args.train_data_path): + train_data_path, split = '/'.join( + FT_API_args.train_data_path.split('/')[:2], + ), FT_API_args.train_data_path.split('/')[-1] + + model = FT_API_args.model + max_seq_len = FT_API_args.context_length + detected_cpu_count = os.cpu_count() or 1 + + common_args = { + 'drop_last': False, + 'num_workers': detected_cpu_count, + 'prefetch_factor': 2, + 'pin_memory': False, + 'persistent_workers': False, + 'timeout': 0, + } + if task_type == 'INSTRUCTION_FINETUNE' or task_type == 'CHAT_COMPLETION': + cfg = om.create({ + 'dataset': { + 'hf_name': train_data_path, + 'split': split, + 'max_seq_len': max_seq_len, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'shuffle': True, + }, + **common_args, + }) + + else: + cfg = om.create({ + 'name': 'finetuning', + 'dataset': { + 'remote': train_data_path, + 'local': train_data_path, + 'split': split, + 'max_seq_len': max_seq_len, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + **common_args, + }) + + tokenizer = build_tokenizer( + tokenizer_name=model, + 
tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + return cfg, tokenizer + + +def token_counts_with_collate(FT_API_args: Namespace): + from llmfoundry import registry + from llmfoundry.data.finetuning import build_finetuning_dataloader + from llmfoundry.utils.registry_utils import construct_from_registry + + cfg, tokenizer = create_om_cfg(FT_API_args) + dataloader = build_finetuning_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=1, + ).dataloader + + detected_cpu_count = os.cpu_count() or 1 + num_cpus_to_use = max(1, detected_cpu_count) + cfg.num_workers = num_cpus_to_use + + dataloader_cfg = { + 'name': 'finetuning', + 'dataset': cfg.dataset, + 'drop_last': cfg.drop_last, + 'num_workers': cfg.num_workers, + 'pin_memory': cfg.pin_memory, + 'prefetch_factor': cfg.prefetch_factor, + 'persistent_workers': cfg.persistent_workers, + 'timeout': cfg.timeout, + } + collate_fn, _ = construct_from_registry( + name='finetuning_collator', + registry=registry.collators, + partial_function=False, + kwargs={ + 'dataloader_cfg': dataloader_cfg, + 'tokenizer': tokenizer, + 'dataset_batch_size': 1, + }, + ) + + def mapper(example: dict): + batch = collate_fn([example]) + return get_num_samples_in_batch(batch) + + token_lens = dataloader.dataset.map( # pyright: ignore + mapper, + batched=False, + num_proc=num_cpus_to_use, + desc='List of Token length', + ) + + return token_lens + + +def get_num_samples_in_batch(batch: dict) -> dict[str, int]: + decoder_only = True + + if not isinstance(batch, Mapping) or ( + 'attention_mask' not in batch and 'input_ids' not in batch + ): + raise ValueError( + 'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key', + ) + + if not decoder_only and 'decoder_attention_mask' not in batch: + raise ValueError( + 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key', + ) + + # Count number of non padding tokens in batch + if 'attention_mask' in batch: + input_ids_tokens = int(torch.sum(batch['attention_mask']).item()) + else: + input_ids_tokens = batch['input_ids'].numel() + + # For encoder decoder models only + decoder_input_ids_tokens = 0 + if not decoder_only: + decoder_input_ids_tokens = int( + torch.sum(batch['decoder_attention_mask']).item(), + ) + + response_tokens = len(batch['labels']) if 'labels' in batch else 0 + + return { + 'ntokens': + input_ids_tokens + decoder_input_ids_tokens + response_tokens, + } + + +def check_HF_datasets(dataset_names_with_splits: list): + token = os.environ.get('HUGGING_FACE_HUB_TOKEN') + for dataset_name_with_split in dataset_names_with_splits: + dataset_name, split = os.path.split(dataset_name_with_split) + # make sure we have a dataset and split + if not dataset_name or not split: + return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Please ensure that you include the split name (e.g. 'mosaicml/dolly_hhrlhf/train')." + # check user access to the dataset + try: + _ = dataset_info(dataset_name) + except: + token_warning = '' + if not token: + token_warning = ' If this is a private dataset, please set your HUGGING_FACE_HUB_TOKEN using: mcli create secret hf.' + return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Please ensure that the dataset exists and that you have access to it. Remember to include the split name (e.g. 'mosaicml/dolly_hhrlhf/train')." 
+ token_warning + # check that split exists + try: + splits = get_dataset_split_names(dataset_name) + except: # error raised in the case of multiple subsets + return False, f'Failed to load Hugging Face dataset {dataset_name_with_split}. Please make sure that the split is valid and that your dataset does not have subsets.' + if split not in splits: + return False, f'Failed to load Hugging Face dataset {dataset_name_with_split}. Split not found.' + return True, '' + + +def is_hf_dataset_path(path: str): + """Check if a given string is a dataset path used by Hugging Face. + + Args: + path (str): The string to be checked. + + Returns: + bool: True if the string is a dataset path, False otherwise. + """ + # Regular expression to match the dataset path pattern + pattern = r'^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/?(train|validation|test)?/?$' + + return bool(re.match(pattern, path)) + + +def is_uc_delta_table(name: str): + """Name is in the form of catalog.schema.tablename. + + Args: + name (str): a string folder/file/table path + Returns: + (bool): True if name is a valid UC delta table format + """ + return '.' in name and '/' not in name and '\\' not in name and len( + name.split('.'), + ) == 3 + + +def integrity_check(out: Union[str, Tuple[str, str]]): + """Check if the index file has integrity. + + If the index is a cloud URL, first download it to a temp local file. + + Args: + out (Union[str, Tuple[str,str]]): MDS dataset path + """ + + def count_shards(mds_root: str): + n_shard_files = 0 + cu = CloudUploader.get(mds_root, exist_ok=True, keep_local=True) + for o in cu.list_objects(): + if o.endswith('.mds'): + n_shard_files += 1 + return n_shard_files + + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + + with tempfile.TemporaryDirectory() as temp_dir: + if cu.remote: + download_file( + os.path.join(cu.remote, 'index.json'), + os.path.join(temp_dir, 'index.json'), + timeout=60, + ) + actual_n_shard_files = count_shards(cu.remote) + local_merged_index_path = os.path.join(temp_dir, 'index.json') + else: + local_merged_index_path = os.path.join(cu.local, 'index.json') + actual_n_shard_files = count_shards(cu.local) + + merged_index = json.load(open(local_merged_index_path, 'r')) + n_shard_files = len({ + b['raw_data']['basename'] for b in merged_index['shards'] + }) + return n_shard_files == actual_n_shard_files + + +def parse_args( + tokenizer: str, + concat_tokens: int, + output_folder: str, + input_folder: str, + compression: str = 'zstd', + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + processes: int = 32, + reprocess: bool = True, +) -> Namespace: + parsed = Namespace( + tokenizer=tokenizer, + concat_tokens=concat_tokens, + output_folder=output_folder, + input_folder=input_folder, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + compression=compression, + processes=processes, + reprocess=reprocess, + ) + # Make sure we have needed concat options + if ( + parsed.concat_tokens is not None and + isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None + ): + raise ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + # now that we have validated them, change BOS/EOS to strings + if parsed.bos_text is None: + parsed.bos_text = '' + if parsed.eos_text is None: + parsed.eos_text = '' + return parsed + + +def download_and_convert_starargs(args: Tuple): + """Helper function to call download_and_convert with star args. + + This helps us use download_and_convert with multiprocessing.
+ """ + return download_and_convert(*args) + + +def download_and_convert( + file_names: list[str], + output_folder: str, + input_folder: str, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +): + """Downloads and converts text files to MDS format. + + Args: + file_names (list[str]): Files to process + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info(f'Starting download and conversion for {len(file_names)} files') + + object_store = maybe_create_object_store_from_uri(input_folder) + + # Download file_names + with tempfile.TemporaryDirectory() as tmp_dir: + log.info(f'Created temporary directory: {tmp_dir}') + downloading_iter = DownloadingIterable( + object_names=file_names, + output_folder=tmp_dir, + object_store=object_store, + ) + log.info(f'Initializing tokenizer: {tokenizer_name}') + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + + # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up + # to the maximum sequence length + dataset = ConcatTokensFromFilesDataset( + files=downloading_iter, + max_length=concat_tokens, + tokenizer=tokenizer, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + ) + + num_samples = sum([1 for _ in dataset]) # pyright: ignore + return num_samples + + +def convert_text_to_mds( + tokenizer_name: str, + output_folder: str, + input_folder: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + processes: int, + args_str: str, + reprocess: bool, + trust_remote_code: bool, +): + """Convert a folder of text files to MDS format. + + Args: + tokenizer_name (str): Name of tokenizer to use + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + processes (int): The number of processes to use. 
+ args_str (str): String representation of the arguments + reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + # Load the tokenizer once on the main process so that the files are cached to avoid race conditions + # in the Hugging Face load code + AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + + is_remote_output = is_remote_path(output_folder) + log.info(f'Output is remote: {is_remote_output}') + + object_names = get_object_names(input_folder) + if len(object_names) == 0: + log.error(f'No text files found in input folder: {input_folder}') + raise InputFolderMissingDataError(input_folder) + + # Check if the text files in the bucket have already been processed. + if not reprocess and is_already_processed( + output_folder, + args_str, + object_names, + ): + log.info( + f'Input folder {input_folder} is already processed at {output_folder} and ' + + + 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', + ) + return + + # Use a temporary local directory if the output is remote and there is more than one process + local_output_folder = tempfile.TemporaryDirectory( + ).name if is_remote_output else output_folder + log.info(f'Using local output folder: {local_output_folder}') + + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + log.error(f'Output folder is not empty: {output_folder}') + raise OutputFolderNotEmptyError(output_folder) + + if processes > 1: + log.info(f'Using multiprocessing with {processes} processes') + # Download and convert the text files in parallel + args = get_task_args( + object_names, + local_output_folder, + input_folder, + processes, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + with ProcessPoolExecutor(max_workers=processes) as executor: + pool = list(executor.map(download_and_convert_starargs, args)) + n_samples = sum(pool) + + log.info('Merging MDS shards from each process') + # Merge the mds shards from each of the processes into a single folder + merge_shard_groups(local_output_folder) + else: + log.info('Using single process for download and conversion') + n_samples = download_and_convert( + object_names, + local_output_folder, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + # Each sample is a fixed-length concatenation of `concat_tokens` tokens. + return n_samples + + +def plot_hist(data: Any, save_plot_path: Optional[str] = None): + import matplotlib.pyplot as plt + + # Figure and Axis Setup + plt.figure(figsize=(10, 6)) + ax = plt.gca() + + # Histogram Plotting + data.hist(bins=100, edgecolor='black', color='skyblue', alpha=0.7, ax=ax) + + # Aesthetics + plt.title('Histogram of Token Counts') + plt.xlabel('Number of Tokens per Sample') + plt.ylabel('Frequency') + + # Grid and Layout + plt.grid(axis='y', alpha=0.75) + plt.tight_layout() + + # Statistical Information (optional) + mean_val = data.mean() + median_val = data.median() + plt.axvline(mean_val, color='red', linestyle='dashed', linewidth=1) + plt.axvline(median_val, color='green', linestyle='dashed', linewidth=1) + _, max_ylim = plt.ylim() + plt.text(mean_val * 1.1, max_ylim * 0.9, f'Mean: {mean_val:.2f}') + plt.text(median_val * 1.1, max_ylim * 0.8, f'Median: {median_val:.2f}') + + if save_plot_path is not None: + plt.savefig(save_plot_path) + + # Show the Plot + plt.show()
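Reviewer note (not part of the patch): the utilities above compose into a short end-to-end flow. A minimal sketch, assuming a local folder of `.txt` files; the paths `/tmp/txt_in` and `/tmp/mds_out` are hypothetical, and the flow mirrors the CPT cell in the notebook below:

```python
from llmfoundry.utils import convert_text_to_mds, integrity_check, parse_args

# Hypothetical local folders; substitute your own input/output paths.
args = parse_args(
    tokenizer='mosaicml/mpt-7b',
    concat_tokens=2048,
    output_folder='/tmp/mds_out',
    input_folder='/tmp/txt_in',
)

# Tokenize and concatenate the text files; the return value is the number of
# fixed-length samples, each exactly `concat_tokens` tokens long.
n_samples = convert_text_to_mds(
    tokenizer_name=args.tokenizer,
    output_folder=args.output_folder,
    input_folder=args.input_folder,
    concat_tokens=args.concat_tokens,
    eos_text=args.eos_text,
    bos_text=args.bos_text,
    no_wrap=args.no_wrap,
    compression=args.compression,
    processes=1,
    args_str=str(args),
    reprocess=args.reprocess,
    trust_remote_code=False,
)
print(f'{n_samples} samples, ~{n_samples * args.concat_tokens} billable tokens')

# For a folder that already holds a materialized MDS dataset, integrity_check()
# compares index.json against the shard files on disk, e.g. (hypothetical URI):
# ok = integrity_check('s3://my-bucket/mds/')
```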
diff --git a/notebooks/validate_and_tokenize_data.ipynb b/notebooks/validate_and_tokenize_data.ipynb new file mode 100644 index 0000000000..1617fdd224 --- /dev/null +++ b/notebooks/validate_and_tokenize_data.ipynb @@ -0,0 +1,966 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "f275a21b-47d4-472c-972b-e2a84a597db2", + "showTitle": false, + "title": "" + } + }, + "source": [ + "# FM FT API: Data Validation and \$Token Estimation\n", + "\n", + "#### Usage Scenario:\n", + "This notebook goes hand-in-hand with the Databricks MosaicML FT API. Our customers may find it useful in scenarios where there is a risk of data being malformed. It acts as a preventive measure to ensure data integrity and helps in cost assessment for the fine-tuning process.\n", + "\n", + "#### Script Purpose:\n", + "- **Not for Training**: This script is not utilized during the training process.\n", + "- **Ad-Hoc Validation**: It serves as an ad-hoc utility for users to run independently prior to starting fine-tuning.\n", + "- **Data Verification**: Its primary function is to validate the user's data before they invoke the Fine-Tuning (FT) API.\n", + "- **Cost Estimation**: Users can estimate the cost implications with this script.\n", + "\n", + "#### Note on Long-Term Solution:\n", + "- **Future Development**: We are in the process of developing a long-term data preparation service, which will eventually replace this script.\n", + "\n", + "#### User Defines:\n", + "- The inputs to this validation script are assumed to be the same as, or a subset of, the FT API arguments, i.e., a configuration like below. \n", + "- For reference, the FT API expects the following:\n", + "```\n", + "cfg = {\n", + " model: str,\n", + " train_data_path: str,\n", + " save_folder: str,\n", + " *,\n", + " task_type: Optional[str] = \"INSTRUCTION_FINETUNE\",\n", + " eval_data_path: Optional[str] = None,\n", + " eval_prompts: Optional[List[str]] = None,\n", + " custom_weights_path: Optional[str] = None,\n", + " training_duration: Optional[str] = None,\n", + " learning_rate: Optional[float] = None,\n", + " context_length: Optional[int] = None,\n", + " experiment_trackers: Optional[List[Dict]] = None,\n", + " disable_credentials_check: Optional[bool] = None,\n", + " timeout: Optional[float] = 10,\n", + " future: Literal[False] = False,\n", + "}\n", + "``` " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3d08a21c-9f5a-4ad2-af85-e016335cc53d", + "showTitle": false, + "title": "" + } + }, + "source": [ + "# Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6f330be7-ff76-4fa2-928f-396367b359ea", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip uninstall -y llm-foundry" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6122e872-44b8-48a3-af61-4b907fc0a71f", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "dbutils.library.restartPython()" + ] + }, + { +
"cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "34e0a248-1d33-4379-841b-6d7d123bbc8a", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install --upgrade --no-deps git+https://github.com/mosaicml/llm-foundry.git@byod/data_validation\n", + "%pip install 'mosaicml[libcloud,wandb,oci,gcs]>=0.23.4,<0.24'\n", + "%pip install 'mlflow>=2.14.1,<2.16'\n", + "%pip install 'transformers>=4.43.2,<4.44'\n", + "%pip install \"mosaicml-streaming>=0.8.0,<0.9\"\n", + "%pip install 'catalogue>=2,<3'\n", + "%pip install 'beautifulsoup4>=4.12.2,<5'\n", + "%pip install -U datasets\n", + "%pip install omegaconf\n", + "%pip install einops\n", + "%pip install sentencepiece" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d9a3d8a4-c89a-40a6-8093-6c2afc2ae08d", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0dcd849e-a35f-4999-acbe-6370c7a29294", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import json\n", + "import tempfile\n", + "import random\n", + "import numpy as np\n", + "import pandas as pd \n", + "from collections import defaultdict\n", + "from argparse import ArgumentParser, Namespace\n", + "from typing import cast \n", + "\n", + "import datasets \n", + "\n", + "from llmfoundry.utils import (create_om_cfg, token_counts_with_collate, \n", + " check_HF_datasets, is_hf_dataset_path, is_uc_delta_table,\n", + " integrity_check, convert_text_to_mds, parse_args, plot_hist,\n", + ")\n", + "\n", + "from llmfoundry.data.finetuning.tasks import (_validate_chat_formatted_example,\n", + " _tokenize_prompt_response_formatted_example,\n", + " _get_example_type, ChatFormattedDict, PromptResponseDict )\n", + "import transformers\n", + "transformers.logging.set_verbosity_error()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3a513cdd-967d-4a87-b56f-340053fa79cd", + "showTitle": false, + "title": "" + } + }, + "source": [ + "# Instruction Fine Tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "cfebdfdf-b87c-4a77-b97c-4697566a55fa", + "showTitle": false, + "title": "" + } + }, + "source": [ + "### Fine-Tuning API Arguments Configuration\n", + "\n", + "This section of the notebook is dedicated to setting up the parameters for the validation notebook. These parameters should be identical to what you specify in Finetuning API. \n", + "\n", + "**Fine-Tuning API Arguments (FT_API_args):**\n", + "\n", + "- model: Specifies the model to be used for fine-tuning. E.g., 'EleutherAI/gpt-neox-20b'\n", + "- train_data_path: The path to the training data. 
It can be a Hugging Face dataset, a path to a JSONL file, or a Delta table.\n", + "- task_type: Defines the type of task for which the training strategy will be applied. It is one of 'INSTRUCTION_FINETUNE', 'CHAT_COMPLETION', or 'CONTINUED_PRETRAIN'.\n", + "- training_duration: The duration of the training process, expressed as a number of training epochs (e.g., 3).\n", + "- context_length: Specifies the context length of the model (e.g., 2048). This determines how many tokens the model considers for each training example.\n", + "\n", + "**Temporary Data Path Configuration:**\n", + "\n", + "- temporary_jsonl_data_path: Defines a filesystem path where temporary data related to the training process will be stored. Make sure the path is not shared by other users on the cluster, as that causes contention.\n", + "- Environment variables for Hugging Face caches (HF_DATASETS_CACHE) are set to a temporary directory under '/tmp/', directing dataset caching there.\n", + "\n", + "**[Supported Models by FT API](https://docs.mosaicml.com/projects/mcli/en/latest/finetuning/finetuning.html#supported-models):**\n", + "\n", + "You need to specify context length based on the model mapping below.\n", + "```\n", + "ft_models = {\n", + " 'mosaicml/mpt-7b-8k': 8192, \n", + " 'mosaicml/mpt-7b': 2048,\n", + " 'mosaicml/mpt-30b': 8192,\n", + " 'meta-llama/Llama-2-13b-hf': 4096,\n", + " 'meta-llama/Llama-2-7b-hf': 4096,\n", + " 'meta-llama/Llama-2-70b-hf': 4096,\n", + " 'codellama/CodeLlama-7b-hf': 16384,\n", + " 'codellama/CodeLlama-13b-hf': 16384,\n", + " 'codellama/CodeLlama-34b-hf': 16384,\n", + " 'mistralai/Mistral-7B-v0.1': 32768,\n", + "}\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0d1f2e9e-db40-41fd-a6b9-bb4757db08b0", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "# Make sure you have write access to the ``home`` directory\n", + "home = os.path.join('/tmp', 'ift')\n", + "os.makedirs(home, exist_ok=True)\n", + "os.chdir(home)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a30e53a6-d3cb-454b-82c0-2b48ca3dbf55", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "FT_API_args = Namespace(\n", + " model= 'mosaicml/mpt-7b', # Other examples: 'EleutherAI/gpt-neox-20b',\n", + " train_data_path= 'mosaicml/dolly_hhrlhf/train', # Other examples: '/path/to/train.jsonl', 'catalog.schema.table', 'iamroot/chat_formatted_examples/train', \n", + " task_type='INSTRUCTION_FINETUNE', # 'CHAT_COMPLETION'\n", + " training_duration=3,\n", + " context_length=2048,\n", + ")\n", + "\n", + "temporary_jsonl_data_path = os.path.join(home, 'ft_data_11Jan24_3/train')\n", + "os.environ['HF_DATASETS_CACHE'] = os.path.join(home, 'hf_cache')\n", + "os.makedirs(temporary_jsonl_data_path, exist_ok=True)\n", + "os.makedirs(os.environ['HF_DATASETS_CACHE'], exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "39c45005-1a77-4162-b9e4-bd8df6f5ec69", + "showTitle": false, + "title": "" + } + }, + "source": [ + "#### Data Loading\n", +
"The IFT data needs to stay with a format \n", + "```\n", + "prompt: xxx\n", + "response or completion: yyy\n", + "```\n", + "\n", + "Based on FT_API_args.train_data_path, we will select an ingestion method from one of the three options below:\n", + "\n", + "- Option-1. data is a JSONL file which stores in an object store supported by Composer.\n", + "- Option-2. data is a Huggingface dataset ID. Note you need to provide a split as well. \n", + "- Option-3. data is a delta table. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "751d8e3a-156c-432c-8e6e-a1530a5a9dc5", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "raw_dataset = None\n", + "\n", + "if is_hf_dataset_path(FT_API_args.train_data_path):\n", + " check_HF_datasets(FT_API_args.train_data_path)\n", + " dataset_id, split = '/'.join(FT_API_args.train_data_path.split('/')[:2]), FT_API_args.train_data_path.split('/')[-1] \n", + " raw_dataset = datasets.load_dataset(dataset_id, split=split) \n", + "else:\n", + " if is_uc_delta_table(FT_API_args.train_data_path): \n", + " df = spark.read.table(FT_API_args.train_data_path).toPandas()\n", + " df.to_json(os.path.join(temporary_jsonl_data_path, 'data.jsonl'), orient='records', lines=True)\n", + " raw_dataset = datasets.Dataset.from_pandas(df) \n", + " FT_API_args.train_data_path = temporary_jsonl_data_path\n", + " else: \n", + " # train_data_path is a jonsl file (local/remote)\n", + " from composer.utils import dist, get_file, parse_uri \n", + " data_path = FT_API_args.train_data_path \n", + " backend, _, _ = parse_uri(data_path)\n", + " if backend not in ['', None]: # It's a remote path, download before loading it\n", + " with tempfile.TemporaryDirectory() as tmp_dir:\n", + " destination = os.path.join(tmp_dir, 'data.jsonl')\n", + " get_file(data_path, destination)\n", + " df = pd.read_json(destination, orient='records', lines=True) \n", + " else: \n", + " df = pd.read_json(data_path, orient='records', lines=True) \n", + "\n", + " raw_dataset = datasets.Dataset.from_pandas(df)\n", + " FT_API_args.train_data_path = os.path.dirname(data_path)\n", + "\n", + "if raw_dataset is None: \n", + " raise RuntimeError(\"Can't find a proper ingestion method\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "06d46367-bd32-473a-9f16-1b34a8dd9356", + "showTitle": false, + "title": "" + } + }, + "source": [ + "#### Data Quality Checks on the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "1a28320a-a2a1-4f3c-a0cd-ad6045a24f64", + "showTitle": false, + "title": "" + } + }, + "source": [ + "\n", + "This section of the notebook performs a series of checks on the initial dataset to ensure its quality and expected format. This process ensures that the dataset adheres to the expected structure and contains the necessary keys for further processing. The checks are outlined below.\n", + "\n", + "1. The total number of examples in the dataset is printed.\n", + "2. The first example from the dataset is displayed. This provides a quick glimpse into the data structure and format.\n", + "3. 
Data Format Validation:\n", + "- The dataset is expected to consist of dictionary-like objects (key-value pairs). A check is performed to validate this structure.\n", + "Each example in the dataset is examined for its compliance with the expected format.\n", + "4. Key Presence Validation:\n", + "- Allowed prompt and response keys, chat roles are defined in [llmfoundry](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/data/finetuning/tasks.py): _ALLOWED_RESPONSE_KEYS and _ALLOWED_PROMPT_KEYS and _ALLOWED_ROLES.\n", + "- For prompt response dataset, the script checks for the presence of at least one prompt key and one response key in each example.\n", + " - Prompt Validation: Each example is checked for the presence of keys defined in _ALLOWED_PROMPT_KEYS. If no valid prompt key is found, it is counted as a format error. \n", + " - Response Validation: Similarly, each example is checked for the presence of keys defined in _ALLOWED_RESPONSE_KEYS. An absence of a valid response key is also counted as a format error.\n", + "- For chat formatted dataset, the script checks if the message content is formatted valid by calling [_validate_chat_formatted_example](https://github.com/mosaicml/llm-foundry/blob/cffd75e94e5c53b1b14c67cd17e0916fecfd0e16/llmfoundry/data/finetuning/tasks.py#L130) helper function.\n", + "Error Reporting:\n", + "\n", + "If any format errors are found during the checks, they are reported.\n", + "A summary of errors is printed, categorizing them into types like data_type (non-dictionary data), missing_prompt, and missing_response.\n", + "If no errors are found, a congratulatory message is displayed, indicating that all checks have passed successfully." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9b89b5c6-bf3a-4425-8645-4840dfeb0848", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "# Initial dataset stats\n", + "print(\"Num examples:\", len(raw_dataset))\n", + "print(\"First example:\")\n", + "for ex in raw_dataset: \n", + " print(ex)\n", + " print() \n", + " break \n", + "\n", + "format_errors = defaultdict(int)\n", + "\n", + "for example in raw_dataset:\n", + " try: \n", + " example_format = _get_example_type(ex)\n", + " except ValueError:\n", + " format_errors[\"unknown example type\"] += 1 \n", + " continue \n", + "\n", + " if example_format == 'chat':\n", + " try: \n", + " chat_example = cast(ChatFormattedDict, example)\n", + " _validate_chat_formatted_example(chat_example)\n", + " except Exception as e: \n", + " format_errors['chat_format_error'] += 1 \n", + " print(e)\n", + " break \n", + "\n", + " elif example_format == 'prompt_response':\n", + " try:\n", + " prompt_response_example: PromptResponseDict = cast(\n", + " PromptResponseDict, example)\n", + " except Exception as e: \n", + " format_errors['prompt_response_format_error'] += 1 \n", + " print(e)\n", + " break \n", + " \n", + "if format_errors:\n", + " print(\"Oops! Found errors:\")\n", + " for k, v in format_errors.items():\n", + " print(f\"{k}: {v}\")\n", + "else:\n", + " print(\"Congratulations! 
No errors found\") " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "9713a0ce-80f4-4187-b10b-4223b17fe4c1", + "showTitle": false, + "title": "" + } + }, + "source": [ + "#### Token Estimation\n", + "\n", + "Tokenize the raw dataset and let's some statistics of the tokens and estimate the overall cost based on default trainining duration\n", + "We will iterate over the dataloader and sum the number of tokens from each batch. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "439d3bd1-0569-456f-8872-3dbafd50cbd7", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "n_epochs = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 \n", + "batch_tokens = token_counts_with_collate(FT_API_args)\n", + "n_billing_tokens_in_dataset = sum(batch_tokens['ntokens'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7249e9e6-1ea7-4fc9-8959-8a17d62a9fb4", + "showTitle": false, + "title": "" + } + }, + "source": [ + "Finetuning API will internally ingest the dataset and run tokenization with the selected tokenizer. \n", + "The output dataset will be a collection of samples. Each sample is a collection of token ids represented as integers. \n", + "We generate a histogram that visualizes the distribution of frequency of token counts in samples in the dataset. \n", + "The visualization aids in identifying patterns, outliers, and central tendencies in the token distribution." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d85eaa02-ee39-4c6b-b14b-8ea1da8bf74d", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be used by the model during training\")\n", + "print(f\"Assume you'll train for {n_epochs} epochs on this dataset\")\n", + "print(f\"Then ~{n_epochs * n_billing_tokens_in_dataset} tokens will be running through the model during training\")\n", + "plot_hist(pd.Series(batch_tokens['ntokens']))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "6699f47f-9b53-47da-95c0-b862c5826d0a", + "showTitle": false, + "title": "" + } + }, + "source": [ + "# Continued Pretrain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "dd37fdce-62d0-493e-bfa9-d823634b2a0d", + "showTitle": false, + "title": "" + } + }, + "source": [ + "### Continued Pretrain API Arguments Configuration\n", + "\n", + "Similar to Instruction Finetune, you need to specify\n", + "\n", + "**Fine-Tuning API Arguments (FT_API_args):**\n", + "\n", + "- model: Specifies the model to be used for fine-tuning. 
E.g., 'EleutherAI/gpt-neox-20b'\n", + "- train_data_path: The path to the training data. We currently only support a (remote/local) path to a collection of .txt files.\n", + "- task_type: Defines the type of task for which the training strategy will be applied. It is either 'INSTRUCTION_FINETUNE' or 'CONTINUED_PRETRAIN'.\n", + "- training_duration: The duration of the training process, expressed as a number of training epochs (e.g., 3).\n", + "- context_length: Specifies the context length of the model (e.g., 2048). This determines how many tokens the model considers for each training example. For Continued Pretraining, we concatenate tokens to form samples of length equal to context_length.\n", + "\n", + "**Temporary Data Path Configuration:**\n", + "\n", + "- temporary_mds_output_path: Defines a filesystem path where data produced by this notebook can be stored. Make sure the path is not shared by other users on the cluster, as that causes contention. For example, you can make it distinguishable by adding your username to the path.\n", + "\n", + "**[Supported Models by FT API](https://docs.mosaicml.com/projects/mcli/en/latest/finetuning/finetuning.html#supported-models):**\n", + "\n", + "You need to specify context length based on the model mapping below.\n", + "```\n", + "ft_models = {\n", + " 'mosaicml/mpt-7b-8k': 8192, \n", + " 'mosaicml/mpt-7b': 2048,\n", + " 'mosaicml/mpt-30b': 8192,\n", + " 'meta-llama/Llama-2-13b-hf': 4096,\n", + " 'meta-llama/Llama-2-7b-hf': 4096,\n", + " 'meta-llama/Llama-2-70b-hf': 4096,\n", + " 'codellama/CodeLlama-7b-hf': 16384,\n", + " 'codellama/CodeLlama-13b-hf': 16384,\n", + " 'codellama/CodeLlama-34b-hf': 16384,\n", + " 'mistralai/Mistral-7B-v0.1': 32768,\n", + "}\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7a773173-2a7f-4605-a7ca-0ece52a905f1", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "FT_API_args = Namespace(\n", + " model= 'mosaicml/mpt-7b',\n", + " train_data_path= os.path.join(home, 'ABT'), # this is the path to your collection of txt files\n", + " task_type='CONTINUED_PRETRAIN',\n", + " training_duration=3,\n", + " context_length=8,\n", + ")\n", + "temporary_mds_output_path = os.path.join(home, 'mds_data_11Jan24_5')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "fc2e4e8b-7700-47c4-bb21-ae4c389f39a2", + "showTitle": false, + "title": "" + } + }, + "source": [ + "Generate a synthetic dataset. Replace train_data_path with your raw data path in practice."
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "10f08422-5091-4e64-b3f7-54928584cd60", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "def generate_synthetic_dataset(folder_path, num_files=128):\n", + " \"\"\"Generate a synthetic dataset of text files with random words.\"\"\"\n", + " def generate_random_words(num_words=50):\n", + " words = [\"apple\", \"banana\", \"cherry\", \"date\", \"elderberry\", \"fig\", \"grape\", \"honeydew\", \"kiwi\", \"lemon\", \"mango\", \"nectarine\", \"orange\", \"papaya\", \"quince\", \"raspberry\", \"strawberry\", \"tangerine\", \"ugli\", \"vanilla\", \"watermelon\", \"xigua\", \"yam\", \"zucchini\"]\n", + " return ' '.join(random.choice(words) for _ in range(num_words))\n", + "\n", + " if not os.path.exists(folder_path):\n", + " os.makedirs(folder_path)\n", + " \n", + " for i in range(num_files):\n", + " file_path = os.path.join(folder_path, f\"file_{i}.txt\")\n", + " with open(file_path, 'w') as file:\n", + " file.write(generate_random_words())\n", + "\n", + " print(f\"Generated {num_files} files in '{folder_path}'.\")\n", + "\n", + "generate_synthetic_dataset(FT_API_args.train_data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "34bcddfb-7d4f-4243-bd02-7ac3e0dce711", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "!rm -rf {temporary_mds_output_path}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "c21e7d1b-db34-4e5d-b6d9-190dc75170d3", + "showTitle": false, + "title": "" + } + }, + "source": [ + "#### Ingestion, Tokenization and Materialization\n", + "\n", + "CPT takes a folder of txt files as input. It tokenizes the text files and materializes them as a streaming dataset in MDS format. \n", + "\n", + "The FT API uses [llmfoundry/scripts/data_prep/convert_text_to_mds.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/data_prep/convert_text_to_mds.py) to download all the txt files and convert them to MDS. \n", + "\n", + "**Warning:** CPT datasets are normally much larger than IFT datasets, so tokenization and materialization can be very time consuming. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "23a56b37-9c1e-4a2e-b8eb-96562ba104f0", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "import os\n", + "os.makedirs(temporary_mds_output_path, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "d2f32fa2-a9b5-4ae8-a54a-0ea329ad1176", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "cfg, tokenizer = create_om_cfg(FT_API_args)\n", + "\n", + "input_folder = FT_API_args.train_data_path\n", + "output_folder = temporary_mds_output_path\n", + "concat_tokens = FT_API_args.context_length\n", + "tokenizer_name = FT_API_args.model\n", + "\n", + "# Run convert_text_to_mds.py and dump MDS dataset to \"save_folder\"\n", + "args = parse_args(tokenizer_name, concat_tokens, output_folder, input_folder)\n", + "\n", + "n_samples = convert_text_to_mds(\n", + " tokenizer_name=args.tokenizer,\n", + " output_folder=args.output_folder,\n", + " input_folder=args.input_folder,\n", + " concat_tokens=args.concat_tokens,\n", + " eos_text=args.eos_text,\n", + " bos_text=args.bos_text,\n", + " no_wrap=args.no_wrap,\n", + " compression=args.compression,\n", + " processes=1,\n", + " reprocess=True,\n", + " args_str=str(args), \n", + " trust_remote_code=False)\n", + "\n", + "n_billing_tokens_in_dataset = n_samples * concat_tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "298eb990-9160-4e1b-958f-33dd2c11b54b", + "showTitle": false, + "title": "" + } + }, + "source": [ + "#### Token Estimation" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5bc58cb3-0a19-4512-9584-642f0a2be4df", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "MAX_TOKENS_PER_EXAMPLE = FT_API_args.context_length if FT_API_args.context_length is not None else 4096\n", + "TARGET_EPOCHS = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 \n", + "n_epochs = TARGET_EPOCHS\n", + "\n", + "print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n", + "print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n", + "print(f\"By default, ~{n_epochs * n_billing_tokens_in_dataset} tokens will be used in training\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "e123669c-2f77-4d66-93eb-04efd546f39f", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "dashboards": [], + "environmentMetadata": { + "base_environment": "", + "client": "1" + }, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 2 + }, + "notebookName": "validate_and_tokenize_data (1)", + "widgets": {} + }, + 
"kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tests/utils/test_validation_notebook_utils.py b/tests/utils/test_validation_notebook_utils.py new file mode 100644 index 0000000000..3024504855 --- /dev/null +++ b/tests/utils/test_validation_notebook_utils.py @@ -0,0 +1,98 @@ +import unittest +from unittest.mock import patch, MagicMock +from argparse import Namespace +import os +import pandas as pd +from collections import defaultdict +from llmfoundry.utils import token_counts_with_collate, create_om_cfg, convert_text_to_mds, parse_args +from llmfoundry.data.finetuning.tasks import _validate_chat_formatted_example, _get_example_type +import datasets + +class TestTrainingWorkflow(unittest.TestCase): + + @patch('llmfoundry.utils.is_hf_dataset_path') + @patch('llmfoundry.utils.is_uc_delta_table') + @patch('datasets.load_dataset') + def test_data_loading(self, mock_load_dataset, mock_is_uc_delta_table, mock_is_hf_dataset_path): + mock_is_hf_dataset_path.return_value = True + mock_load_dataset.return_value = MagicMock() + FT_API_args = Namespace(model='mosaicml/mpt-7b', train_data_path='mosaicml/dolly_hhrlhf/train', task_type='INSTRUCTION_FINETUNE', training_duration=3, context_length=2048) + + if mock_is_hf_dataset_path(FT_API_args.train_data_path): + dataset_id, split = '/'.join(FT_API_args.train_data_path.split('/')[:2]), FT_API_args.train_data_path.split('/')[-1] + raw_dataset = datasets.load_dataset(dataset_id, split=split) + else: + self.fail("HF dataset path mock did not return True") + + self.assertIsNotNone(raw_dataset) + + @patch('pandas.DataFrame.to_json') + @patch('pandas.read_json') + @patch('llmfoundry.utils.is_hf_dataset_path', return_value=False) + @patch('llmfoundry.utils.is_uc_delta_table', return_value=True) + def test_delta_table_data_loading(self, mock_is_uc_delta_table, mock_is_hf_dataset_path, mock_read_json, mock_to_json): + mock_df = pd.DataFrame({'example': [1, 2, 3]}) + mock_read_json.return_value = mock_df + FT_API_args = Namespace(model='mosaicml/mpt-7b', train_data_path='catalog.schema.table', task_type='INSTRUCTION_FINETUNE', training_duration=3, context_length=2048) + + df = mock_df + df.to_json('dummy_path', orient='records', lines=True) + raw_dataset = datasets.Dataset.from_pandas(df) + + self.assertIsNotNone(raw_dataset) + mock_to_json.assert_called_once() + + def test_data_quality_checks(self): + raw_dataset = [{'prompt': 'test prompt', 'response': 'test response'}] + format_errors = defaultdict(int) + + for example in raw_dataset: + try: + example_format = _get_example_type(example) + except ValueError: + format_errors["unknown example type"] += 1 + continue + + if example_format == 'chat': + try: + _validate_chat_formatted_example(example) + except Exception: + format_errors['chat_format_error'] += 1 + + elif example_format == 'prompt_response': + try: + _ = example + except Exception: + format_errors['prompt_response_format_error'] += 1 + + self.assertEqual(len(format_errors), 0) + + @patch('llmfoundry.utils.token_counts_with_collate') + def test_token_estimation(self, mock_token_counts_with_collate): + mock_token_counts_with_collate.return_value = {'ntokens': [1000, 2000, 3000]} + 
FT_API_args = Namespace(model='mosaicml/mpt-7b', task_type='INSTRUCTION_FINETUNE', training_duration=3) + + n_epochs = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 + batch_tokens = token_counts_with_collate(FT_API_args) + n_billing_tokens_in_dataset = sum(batch_tokens['ntokens']) + + self.assertEqual(n_billing_tokens_in_dataset, 6000) + + @patch('llmfoundry.utils.create_om_cfg') + @patch('llmfoundry.utils.convert_text_to_mds') + def test_continued_pretrain(self, mock_convert_text_to_mds, mock_create_om_cfg): + FT_API_args = Namespace(model='mosaicml/mpt-7b', train_data_path='/tmp/ABT', task_type='CONTINUED_PRETRAIN', training_duration=3, context_length=8) + temporary_mds_output_path = '/tmp/mds_data_11Jan24_5' + + cfg, tokenizer = MagicMock(), MagicMock() + mock_create_om_cfg.return_value = (cfg, tokenizer) + + n_samples = mock_convert_text_to_mds.return_value = 10 + n_billing_tokens_in_dataset = n_samples * FT_API_args.context_length + + self.assertEqual(n_billing_tokens_in_dataset, 80) + mock_convert_text_to_mds.assert_called_once() + +if __name__ == '__main__': + unittest.main() +
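Reviewer note (not part of the patch): the two path classifiers added in `validation_notebook_utils.py` have simple, checkable semantics. A quick sanity sketch with illustrative values, separate from the test suite:

```python
from llmfoundry.utils import is_hf_dataset_path, is_uc_delta_table

# 'org/dataset' with an optional train/validation/test suffix matches the HF pattern.
assert is_hf_dataset_path('mosaicml/dolly_hhrlhf/train')
assert is_hf_dataset_path('mosaicml/dolly_hhrlhf')
assert not is_hf_dataset_path('/local/path/to/train.jsonl')

# A UC Delta table is exactly three dot-separated parts with no slashes.
assert is_uc_delta_table('main.default.my_table')
assert not is_uc_delta_table('main.default')
assert not is_uc_delta_table('folder/file.jsonl')
```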
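And a tiny worked example of the token-counting helper on a hand-built batch (illustrative tensors). Note that when an `attention_mask` is present, padding tokens are excluded, and that `len(batch['labels'])` contributes the batch dimension rather than a label-token count:

```python
import torch

from llmfoundry.utils.validation_notebook_utils import get_num_samples_in_batch

batch = {
    'input_ids': torch.tensor([[11, 22, 33, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0, 0]]),  # 3 real tokens, 2 padding
    'labels': torch.tensor([[33, 44]]),  # len(...) == 1, the batch dimension
}

# ntokens = 3 (non-padding inputs) + 0 (decoder-only) + 1 (len of labels) = 4
assert get_num_samples_in_batch(batch) == {'ntokens': 4}
```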