diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index ef0e32c5c5..62bfdb818a 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -282,15 +282,8 @@ def _tokenize_chat_formatted_example(
     }
 
 
-<<<<<<< HEAD
 def _validate_prompt_response_formatted_example(example: PromptResponseDict):
     """Validate expected keys."""
-=======
-def _tokenize_prompt_response_formatted_example(
-    example: PromptResponseDict,
-    tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
-    """Tokenize a formatted example and validate expected keys."""
->>>>>>> c404dc7ec03897c283a3cae0473a65257b51a2aa
     example_keys = set(example.keys())
     prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
     response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)
@@ -322,7 +315,6 @@ def _tokenize_prompt_response_formatted_example(
             f'Unable to tokenize example because {response_key} was not a string. {example=}'
         )
 
-<<<<<<< HEAD
     return prompt, response
 
 def _tokenize_prompt_response_formatted_example(
@@ -331,8 +323,6 @@ def _tokenize_prompt_response_formatted_example(
     """Tokenize a formatted example and validate expected keys."""
     prompt, response = _validate_prompt_response_formatted_example(example)
 
-=======
->>>>>>> c404dc7ec03897c283a3cae0473a65257b51a2aa
     # Note: We default to the tokenizer's add_bos_token and add_eos_token behavior here
     # (which we do not do for chat-formatted examples). This is because chat examples specifically
     # go through the tokenizer's `apply_chat_template` method, which handles special tokens,
diff --git a/notebooks/validate_and_tokenize_data.ipynb b/notebooks/validate_and_tokenize_data.ipynb
index b749fc46f9..926b6a11aa 100644
--- a/notebooks/validate_and_tokenize_data.ipynb
+++ b/notebooks/validate_and_tokenize_data.ipynb
@@ -73,7 +73,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -93,7 +93,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -141,7 +141,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -161,7 +161,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -185,14 +185,18 @@
     "import pandas as pd \n",
     "from collections import defaultdict\n",
     "from argparse import ArgumentParser, Namespace\n",
+    "from typing import cast \n",
     "\n",
     "import datasets \n",
     "\n",
-    "from llmfoundry.utils import (create_om_cfg, token_counts_and_validation, token_counts, \n",
+    "from llmfoundry.utils import (create_om_cfg, token_counts_with_collate, token_counts, \n",
     "                              check_HF_datasets, is_hf_dataset_path, is_uc_delta_table,\n",
     "                              pandas_processing_fn, integrity_check, convert_text_to_mds, parse_args, \n",
     "                              _args_str, plot_hist, dataframe_to_mds)\n",
     "\n",
+    "from llmfoundry.data.finetuning.tasks import (_validate_chat_formatted_example,\n",
+    "                                              _validate_prompt_response_formatted_example,\n",
+    "                                              _get_example_type, ChatFormattedDict, PromptResponseDict )\n",
     "import transformers\n",
     "transformers.logging.set_verbosity_error()"
    ]
   },
@@ -268,7 +272,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -291,7 +295,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -308,7 +312,7 @@
    "source": [
     "FT_API_args = Namespace(\n",
     "    model= 'mosaicml/mpt-7b', # Other examples: 'EleutherAI/gpt-neox-20b',\n",
-    "    train_data_path= 'mosaicml/dolly_hhrlhf/train', # Other examples: '/path/to/train.jsonl', 'catalog.schema.table'\n",
+    "    train_data_path= 'mosaicml/dolly_hhrlhf/train', # Other examples: '/path/to/train.jsonl', 'catalog.schema.table', 'iamroot/chat_formatted_examples/train', \n",
     "    task_type='INSTRUCTION_FINETUNE',\n",
     "    training_duration=3,\n",
     "    context_length=2048,\n",
@@ -352,7 +356,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -454,7 +458,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 0,
    "metadata": {
     "application/vnd.databricks.v1+cell": {
      "cellMetadata": {
@@ -477,30 +481,33 @@
     "    print() \n",
     "    break \n",
     "\n",
-    "_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}\n",
-    "_ALLOWED_PROMPT_KEYS = {'prompt'}\n",
     "format_errors = defaultdict(int)\n",
     "\n",
-    "for ex in raw_dataset:\n",
-    "    if not isinstance(ex, dict):\n",
-    "        format_errors[\"data_type\"] += 1 \n",
+    "for example in raw_dataset:\n",
+    "    try: \n",
+    "        example_format = _get_example_type(example)\n",
+    "    except ValueError:\n",
+    "        format_errors[\"unknown example type\"] += 1 \n",
     "        continue \n",
-    "    \n",
-    "    found = False \n",
-    "    for key in _ALLOWED_PROMPT_KEYS:\n",
-    "        prompts = ex.get(key, None)\n",
-    "        if prompts:\n",
-    "            found = True \n",
-    "    if not found: \n",
-    "        format_errors[\"missing_prompt\"] += 1\n",
-    "\n",
-    "    found = False\n",
-    "    for key in _ALLOWED_RESPONSE_KEYS: \n",
-    "        responses = ex.get(\"response\", None)\n",
-    "        if responses: \n",
-    "            found = True \n",
-    "    if not found:\n",
-    "        format_errors[\"missing_response\"] += 1\n",
+    "\n",
+    "    if example_format == 'chat':\n",
+    "        try: \n",
+    "            chat_example = cast(ChatFormattedDict, example)\n",
+    "            _validate_chat_formatted_example(chat_example)\n",
+    "        except Exception as e: \n",
+    "            format_errors['chat_format_error'] += 1 \n",
+    "            print(e)\n",
+    "            break \n",
+    "\n",
+    "    elif example_format == 'prompt_response':\n",
+    "        try:\n",
+    "            prompt_response_example: PromptResponseDict = cast(\n",
+    "                PromptResponseDict, example)\n",
+    "            _validate_prompt_response_formatted_example(prompt_response_example)\n",
+    "        except Exception as e: \n",
+    "            format_errors['prompt_response_format_error'] += 1 \n",
+    "            print(e)\n",
+    "            break \n",
    "    \n",
    "if format_errors:\n",
    "    print(\"Oops! Found errors:\")\n",
@@ -527,12 +534,13 @@
    "source": [
     "#### Token Estimation\n",
     "\n",
-    "Tokenize the raw dataset and let's some statistics of the tokens and estimate the overall cost based on default trainining duration"
+    "Tokenize the raw dataset, compute some statistics of the tokens, and estimate the overall cost based on the default training duration.\n",
+    "We will iterate over the dataloader and sum the number of tokens from each batch."
" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -548,7 +556,7 @@ "outputs": [], "source": [ "n_epochs = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 \n", - "batch_tokens = token_counts(FT_API_args)\n", + "batch_tokens = token_counts_with_collate(FT_API_args)\n", "n_billing_tokens_in_dataset = sum(batch_tokens['ntokens'])" ] }, @@ -575,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -666,7 +674,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -711,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -747,7 +755,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -793,8 +801,19 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "23a56b37-9c1e-4a2e-b8eb-96562ba104f0", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "import os\n", @@ -803,7 +822,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -866,7 +885,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -892,7 +911,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { @@ -916,7 +935,7 @@ "notebookMetadata": { "pythonIndentUnit": 2 }, - "notebookName": "validate_and_tokenize_data", + "notebookName": "validate_and_tokenize_data (1)", "widgets": {} }, "kernelspec": { @@ -938,5 +957,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 }