From 8e8f4316a5f8d5d2180c15f37353acced1c935b7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 14 Mar 2024 15:51:29 -0700 Subject: [PATCH] Validation (#1034) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add validation script * update * change token count function * reorganize cells * Add unit tests * Add a printout for CPT * update question * Add questions * Fix lints * update format * update * nb source * add validation script * update * change token count function * reorganize cells * Add unit tests * Add a printout for CPT * update question * Add questions * Fix lints * update format * update * nb source * Remove license insert for validation notebook * Add validation utils * Minor cleanups (#858) * nits * logger * add log * lint * update utils/__init__.py to include extra validation functions * update notebook * update * update * Read UC delta table (#773) * initial commit * use databricks-sql to read delta table and convert to json * update * update * update * add mocked unittest * Fix lints * update * update * restructure code * Add timer for optimizing * Add db-connect * add wrapper * update * add install dbconnect * update * update * patch dbconnect to allow multiple return formats * update * add arrow * use compression * clean up * Add cluster rt check * Fix lints * remove patch.py for CI * update * update * updat * update * fix tests * fix lint * update * update * Add more tests * update * update * update * change to download_json * update * fix lints * Add decompressed option for arrow * format json to jsonl * Add comments * Make cf_collect_type global option * fix comments * fix lints * fix comments * Fix lints * change to use workspaceclient * Add CPT support * Rewire method assignment logic * Fix bug in stripping https * Add tests for rewired method assignment logic * Fix lints * Fix lints * Removed logger set_level * Remove pyspark. It conflicts with databricks-connect * Update the comment * skip cluster version check when cluster_id is serverless * Add use_serverless flag * update tests with use_serverless flag * Fix lints --------- Co-authored-by: Xiaohan Zhang * Add download remote function to util * update * remove fused layernorm (#859) * update * update * update * update * update * update * update * update * update * Remove hardcoded combined.jsonl with a flag (#861) * Remove hardcoded combined.jsonl with a flag * update * change output_json_path output_json_folder --------- Co-authored-by: Xiaohan Zhang * bump (#828) * Add dask and dataframe_to_mds * update * update * update * update * Add notebook * update * update * remove script and tests, keep notebook * update * update * update * update * Always initialize dist (#864) * fix dev * lint * remove gpu * updated notebook * remove scripts keep notebook * update notebook. rephrase. * Logs upload URI (#850) * fix style etc. * fix * fix fix * fix fix fix * fix fix fix fix * removed unused dummy func * deleted tests to make the tests pass * tried adding back some tests to see if it triggers the issue * add test_hf_checkpointer.py but remove references to MPT * fix? * fixed test cases overlapping in strange side-effecty ways * update * Delta to JSONL conversion script cleanup and bug fix (#868) * Small test change * small cleanups * lint and precommit * lint and precommit * comments * another one * pr suggestion and use input param not args * fix mock (#872) * Add response tokens * update * fix regex (#877) * Precompute flash attention padding info (#880) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * dummy data * undoing last commit * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * .. * .. --------- Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * add missing import (#882) * fsdp wrap refac (#883) * fsdp wrap refac * refac * refac * Update model download utils to support ORAS (#881) * wip * wip * Accept registry file for hostname * Make sure no sensitive info is surfaced in subprocess error * Refactor model downloading * Save HF hub files to local dir * fallback * Remove commented code * Update logging * Update HTP download args * Use files for ORAS * Update llmfoundry/utils/model_download_utils.py Co-authored-by: Irene Dea --------- Co-authored-by: Irene Dea * Update license (#887) Updates the license for 2024. New files will have a copyright year of 2024 inserted in the header. Existing files will not be changed. * Fix tiktoken add generation prompt (#890) * update * Upgrade Datasets version (#892) * Disable MDSWrite, return token counts * Bump transformers version to support Mixtral (#894) * Add `tokenizer-only` flag to only download tokenizers from HF or oras (#895) * Foundational Model API eval wrapper (#849) * FMAPI model wrapper * add chat wrapper too * revert * end line * formatting * less verbose * better error messages * Change plot settings * update notebook * update * update notebook * update * update notebook * Add better error for non-empty local output folder in convert_text_to_mds.py (#891) * Allow bool input for loggers (#897) * Allow bool input for loggers * Convert earlier on * Fix test case * Enable QK Group Norm (#869) * start qkgn * attn defaults for qk_gn * impl qk_gn * Update attention.py * Update attention.py * Update test_flash_triton_torch.py * Update attention.py * Update test_flash_triton_torch.py * Update attention.py * lint * Update attention.py * lint * add avlue error * Update attention.py * updt to include low precision groupnorm; * perf improvement * Revert "perf improvement" This reverts commit 2b62d5ecd21e13cb1bcd0883b3f6ebd1229e9d1d. * Revert "updt to include low precision groupnorm;" This reverts commit bca1c3383f5d2ea3009d4ee297ccc26db146cf20. * patch (#905) * Add new GC option (#907) * No symlinks at all for HF download (#908) * Adds support for chat formatted finetuning input data. (#884) * fix conflicting formatting linting guidelines * used older union operator for legacy support * did the same thing in another place * isort ignore specific lines * fixes * isort do not skip line * address comments * renamed some more things * split tests and add some verification for tokenization split * fix formatting * added docstrings * added end-to-end-test with HF dataset * fix code style * renamed file and fixed tests * use chat template diff * addressed comment * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * fixed type of TokenizedExample * use cast * use _ALLOWED_{PROMPT, RESPONSE}_KEYS * updated tests * fix * fix? * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add flag to enable/disable param upload (#912) * Add flag to enable/disable param upload * Yapf * Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Rename * Add to eval --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add support for eval_loader & eval_subset_num_batches in async callback (#834) * Skip evalloader in training if using async eval * add support for subset_num_batches * remove todo * eval first * rename arg * fix * small updates * om * fix test * eval run config --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add the model license file for mlflow (#915) * Warn instead of error on tokenizer-only with http (#904) * Fix fmapi_chat for instruct models and custom tokenizers (#914) * Fix fmapi_chat for instruct models and custom tokenizers * remove from tiktoken * fix * add tests * fix test, 0->1 * refactor * Make yamllint consistent with Composer (#918) * Create HF checkpointer model on meta device (#916) * Tiktoken chat format fix (#893) * sys prompt fix * remove eos tokens from chat formatter * fix dash issue (#919) * fix dash issue * fix * fix? * added unit test * fix fix * fix tests * fix fix tests * Fixes yaml linting (#920) * Adding deprecation warning for Flash Attention 1 and user warning against using Triton attention. (#921) * Add rich formatting to tracebacks (#927) * added rich traceback * sorted imports * added rich to eval * Changes to setup.py invalidate docker cache. Use branch name in dockerfile (#930) Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Remove .ci folder and move FILE_HEADER (#931) * Throw error when no EOS (#922) * bump (#934) * Update eval_gauntlet_callback.py with math.log2 (#821) Saw an automated ruff flag this, seems like a strict improvement and is marginally faster. Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Switch to the Composer integration of LoRA (works with FSDP) (#886) * Refactoring the function to accept list of metric names instead of a dictionary of metrics. (#938) * .. * undoing prev commit * Refactoring the function to accept list of metric names instead of dictionary * .. * .. * .. * .. * Remove extra call to .to and load_state_dict in hf checkpointer (#939) * Fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking (#940) * .. * undoing prev commit * fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking * Update modeling_mpt.py * .. --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update lora docs (#941) * fix (#942) * Retrieve license information when local files are provided for a pretrained model (#943) * Initial implementation to test * Add log for license overwrite * Use Path for input to _write_license_information * Set default --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add and use VersionedDeprecationWarning (#944) * Add and use VersionedDeprecationWarning * Use remove_version instead. * Fix merge * Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Bump llm-foundry version to 0.5.0 (#948) * Bump version to 0.5.0 * Remove deprecated features * Other cleanup * code quality * Fix chain-of-thought tasks (#824) * Skip flaky lion8b test (#598) * relax atol and add retries to reduce flakiness in lion8b timing test * add eval output logging * add back tasks * foo * add rlhf prompts * add rlhf prompts * add rlhf prompts * add rlhf prompts * add rlhf prompts * fix prompt * fix prompt * modify mcli * test * test * fix * added math dataset * edit yaml * prep gsm8k identically to eleuther * prep gsm8k identically to eleuther * add early stopping criteria * finish * debug * fix * bug * remove eval output logging callback * restore * fix * fix * fix composer verion * gauntlet v0.2.1 * gauntlet v0.2.1 * prep * prep * foo * restore * restore * restore mcli * fix precommit * fix * Update hf_eval.yaml * fix * fix * remove programming * update readme --------- Co-authored-by: dblalock Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add finetuning streaming dataset conversion (#933) * add convert * fix * fix convert * add jsonl * revert setup * test precommit * pre-commit * test pre-commit * review comments * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update llmfoundry/data/finetuning/tasks.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update scripts/data_prep/convert_finetuning_dataset.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add default signature to mlflow saved model (#952) * allow te to use meta device with deferred init (#958) * Update TUTORIAL.md (#957) * Update TUTORIAL.md fix indentation problem * Update TUTORIAL.md --------- Co-authored-by: Mihir Patel Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Bump mcli yaml foundry version to v0.5.0 (#959) * add finutuning with streaming dataset example (#945) * add convert * fix * fix convert * add jsonl * revert setup * test precommit * pre-commit * test pre-commit * v0 * review comments * temporarily trigger test * test * fix yaml * comments * comments * comments * add unit test * comments --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add fully configurable activation checkpointing (#951) * add fully configurable activation checkpointing * fix format * fix format * add docstring to activation_checkpointing_fn * add block id range option in act ckpt * resolve conflict * add a check for blocks ids overlap in mapping * fix typo * update docstring * refactor * fix test * Apply suggestions from code review Co-authored-by: Mihir Patel * address comments * add build mapping as a helper func * fix format --------- Co-authored-by: Mihir Patel * Use create_model_version instead of register_model (#953) * Add streams support (#946) * add convert * fix * fix convert * add jsonl * revert setup * test precommit * pre-commit * test pre-commit * v0 * review comments * temporarily trigger test * test * add convert * fix * v0 * fix * fix MDS write * streams support * fake commit * fix setup * format * add back arxiv * trigger test * review comments * temporarily trigger test * test * add convert * fix * fix * fix MDS write * format * trigger test * fix * format * resolve conflicts * add back jsonl * fix yaml * comments * format * comments * comments * add unit test * comments * comments * merge * format * typo * Update llmfoundry/data/finetuning/dataloader.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Fix typo (#966) * Fix eval.py with lora (#965) * just remove it? * or not * fix * fix up * clean up * fix example yaml * precommit * add test * add memorysnapshot to callbacks (#810) Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Adding curriculum learning callback (experimental) (#954) * curriculum learning callback * curriculum learning callback * fixing types * dataset config types correct * dataset config retrieved correctly * access train dataloader correctly * load state dict defaults * get that damn dataloader * missed dat * dataspec L * dataset L * no logging, print is my best friend * save first dataset config * don't save new dataset config every single time * logging dataset state * have to set the damn timestamp. rip * remove logging * linting * pyright * removing rope... * Delete scripts/eval/local_data/.DS_Store * trailing comma is bacc * fixed docstring * fixed docstrings * no more funky stuff in save_dict * refactored, assuming before_load event in composer * lingint * bumped composer and streaming min versions * moved line * strengthened chat formatting validation (#960) * strengthened chat formatting validation * fix types * made assert messages more descriptive * used raise instead of assert, added type checks * added list type check * type error if no string content * add test case for new validation * relaxed type constraints to interface minimum * use Mapping and Iterable * fix mapping in type aliases too * iterable -> sequence * sequence -> list * Mapping -> Dict * use mapping again * fixed another one * updated message * factored out duplicate functions * dict -> mapping * add sequence * Add new base images and remove fa1 images (#970) * Add new ICL kwargs in eval.py and long_context yamls (#925) * add yamls w/ old links * load from max's public hf and parse hf datasets * update rest of tasks * add better logging * implemented leval tasks * move level * add level yaml * add str parsing to hf * wip * llm-foundry working with new parser * working w/ new parsing * fix old long context tasks * wip * wip * wip * wip * update to hf_parsing_map * rm defaults * fix parsing vars * update defaults again * rm merge conflict * fix gen_kwargs * rm old code path * fixups * wip * rm leval from pr * fix comments in yamls * add cot params * add fewshot_random_seed * fix early_stopping_criteria, fewshot_num_seed default * undo rm hf_eval * add fewshot_random_seed to test * add 64k tasks * add longer context, update composer versin * address comments * mixed * use seed by default * rm long_context_eval_8k.yaml * add longer context evals * mv yamls * eval gauntlet wip * update niah and wikiqa * fix linting * add default option * change defaults * fix linting * fix linting 2 --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Make Composer pins consistent with each other (#972) * Make turbo an optional dependency (#964) * Fix fewshot_random_seed default setting (#974) * del fewshot_random default, fix hf_eval, fix gauntlet readme * set in cfg defaults area * fix the fix i applied that was actually not a fix * rm num_batch from hf_eval * improve error msg when checking target_blocks in activation_checkpointing_target (#977) * Torch 2.2 upgrade - Part 1 (#976) * Torch 2.2 - Part 2 (#979) * PyTorch 2.2 - Part 3 (#981) * Remove torch 2.1 from docker build (#982) * Async callback: Don't skip checkpoints, reliably only launch async eval when the checkpoint is ready (#813) * working without sharded checkpointing.. * add more debugs * try this * more debugging * yikes dumb bug * add notes * fixes * remove prints * small updates * fix typo * refactor * fix docstring formatting * fighting with docstrings * try this * add unit tests * point to composer update * values -> items * serialize time * fix merge * nits * warning, small comment update * add error --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Token accuracy metrics (#983) * do not mention 1.13 in readme (#988) Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Patch test, lock mcli version (#990) * Bump gha timeouts (#991) * Fix readme typo (#993) * if condition in tie weights added (#989) * if condition in tie weights added * unit test for tie weights * bump composer version (#995) * Trim examples ahead of time for auto packing (#994) * add oom observer callback (#932) * add oom observer callback * fix format * Change ci/cd to use ci-testing repo * Revert "Change ci/cd to use ci-testing repo" This reverts commit e3f214e71033ed708ff5db224e986da712baa80b. * Use ci-testing repo (#1000) Co-authored-by: Irene Dea * Make CodeEval respect device_eval_batch_size (#956) * Remove try except around imports (#1004) * Deprecate triton, prefix lm, llama attention patch, and text denoising; Make ComposerHFT5 experimental (#1007) * Deprecate features and mark experimental * fix typo --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * add magic filename for sharded state dicts (#1001) * add magic filename for sharded state dicts * Update scripts/train/train.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * oops forgot to push this * no shard if no fsdp * default to full on foundry --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * bump (#1009) * Fix evaluators actually pulling eval metrics (#1006) * fix bug on metrics * lint * lint * add unit test * lint * Build torch 2.2.1 images (#1010) * add 2.2.1 tests (#1011) * Bump min torch pin (#1013) Red button because CI running jobs it doesn't need. Tests passed on main. * Fix extra BOS token in front of response for some tokenizers (#1003) * Bump min composer pin (#1015) * add default for eval interval (#987) Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Add support for olmo (#1016) * Add deeper support for multi-turn chats and loss-generating tokens in finetuning (#985) The main purpose of this PR is to support training on non-terminal responses in multi-round chats. This is achieved by tokenizing at the level of conversation "turns" and exposing some options for what turns are used as training targets (i.e. generate loss). This also adds support for treating prompt tokens as loss-generating. The script for converting a finetuning dataset to streaming has also been updated (with some bug fixes). * Fix profiling packing ratio to explicitly say 1 (#1019) * Bump transformers to 4.38.2 (#1018) * that kwargs (#1020) * Update readme with pytorch 2.2.1 (#1021) * Add code import to train/eval scripts (#1002) * finish (#1022) Co-authored-by: Max Marion * Bump version to 0.6.0 (#1023) * Fix typo in monolithic chkpt callback docs (#1024) * Fix typo in monolithic chkpt callback docs * reorder to match function signature * update pip install link * Change done file location * Create the dest folder * Allow code-quality workflow to be callable (#1026) Reverts part of the change made in https://github.com/mosaicml/llm-foundry/pull/1000/files#diff-4a2765c2cfcbd3804a66aab805cb92ddda74de1730923cc5bf53671d0beccf06L11 * update notebook * update * update notebook * update token_counts * update pip install list * fix * update * fix token counts * Expose validate chat * Expose more * update * expose * add collate * Fix --------- Co-authored-by: Xiaohan Zhang Co-authored-by: xiaohanzhan-db Co-authored-by: Mihir Patel Co-authored-by: Milo Cress Co-authored-by: Nancy Hung Co-authored-by: Jerry Chen Co-authored-by: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Co-authored-by: Irene Dea Co-authored-by: Brian <23239305+b-chu@users.noreply.github.com> Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> Co-authored-by: Anna Co-authored-by: Nicholas Garcia Co-authored-by: Prithviraj Ammanabrolu Co-authored-by: Jane Zhang Co-authored-by: Vincent Chen Co-authored-by: Aaron Gokaslan Co-authored-by: Jeremy D <115047575+bmosaicml@users.noreply.github.com> Co-authored-by: dblalock Co-authored-by: bigning Co-authored-by: Cheng Li Co-authored-by: Sebastián Donoso Bustos Co-authored-by: Saaketh Narayan Co-authored-by: Max Marion Co-authored-by: Megha Agarwal <16129366+megha95@users.noreply.github.com> Co-authored-by: Jose Javier <26491792+josejg@users.noreply.github.com> Co-authored-by: Alex Trott Co-authored-by: Sasha Doubov --- llmfoundry/data/finetuning/__init__.py | 16 +++++++++-- llmfoundry/data/finetuning/tasks.py | 16 +++++++++++ llmfoundry/utils/__init__.py | 4 +-- llmfoundry/utils/validation_utils.py | 37 ++++++++++++++++---------- 4 files changed, 55 insertions(+), 18 deletions(-) diff --git a/llmfoundry/data/finetuning/__init__.py b/llmfoundry/data/finetuning/__init__.py index 9d10a17cfa..b2375ab03b 100644 --- a/llmfoundry/data/finetuning/__init__.py +++ b/llmfoundry/data/finetuning/__init__.py @@ -2,6 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator -from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader, _build_collate_fn +from llmfoundry.data.finetuning.tasks import (_validate_chat_formatted_example, + _validate_prompt_response_formatted_example, + _get_example_type, PromptResponseDict, ChatFormattedDict) -__all__ = ['Seq2SeqFinetuningCollator', 'build_finetuning_dataloader'] +__all__ = [ + 'Seq2SeqFinetuningCollator', + 'build_finetuning_dataloader', + '_build_collate_fn', + '_validate_chat_formatted_example', + '_validate_prompt_response_formatted_example', + '_get_example_type', + 'PromptResponseDict', + 'ChatFormattedDict' +] diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 103ba71215..ef0e32c5c5 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -282,10 +282,15 @@ def _tokenize_chat_formatted_example( } +<<<<<<< HEAD +def _validate_prompt_response_formatted_example(example: PromptResponseDict): + """Validate expected keys.""" +======= def _tokenize_prompt_response_formatted_example( example: PromptResponseDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: """Tokenize a formatted example and validate expected keys.""" +>>>>>>> c404dc7ec03897c283a3cae0473a65257b51a2aa example_keys = set(example.keys()) prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS) response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS) @@ -317,6 +322,17 @@ def _tokenize_prompt_response_formatted_example( f'Unable to tokenize example because {response_key} was not a string. {example=}' ) +<<<<<<< HEAD + return prompt, response + +def _tokenize_prompt_response_formatted_example( + example: PromptResponseDict, + tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + """Tokenize a formatted example and validate expected keys.""" + prompt, response = _validate_prompt_response_formatted_example(example) + +======= +>>>>>>> c404dc7ec03897c283a3cae0473a65257b51a2aa # Note: We default to the tokenizer's add_bos_token and add_eos_token behavior here # (which we do not do for chat-formatted examples). This is because chat examples specifically # go through the tokenizer's `apply_chat_template` method, which handles special tokens, diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index b01414acd0..1234158084 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -24,7 +24,7 @@ pandas_processing_fn, parse_args, plot_hist, token_counts, - token_counts_and_validation) + token_counts_with_collate) # yapf: enable __all__ = [ @@ -44,7 +44,7 @@ 'log_config', 'pop_config', 'create_om_cfg', - 'token_counts_and_validation', + 'token_counts_with_collate', 'token_counts', 'check_HF_datasets', 'is_hf_dataset_path', diff --git a/llmfoundry/utils/validation_utils.py b/llmfoundry/utils/validation_utils.py index 26dae7f0d9..3dfecd442b 100644 --- a/llmfoundry/utils/validation_utils.py +++ b/llmfoundry/utils/validation_utils.py @@ -83,24 +83,33 @@ def create_om_cfg(FT_API_args: Namespace): return cfg, tokenizer -def token_counts_and_validation(FT_API_args): - from llmfoundry.data.finetuning import build_finetuning_dataloader +def token_counts_with_collate(FT_API_args): + from llmfoundry.data.finetuning import build_finetuning_dataloader, _build_collate_fn cfg, tokenizer = create_om_cfg(FT_API_args) + detected_cpu_count = os.cpu_count() or 1 + num_cpus_to_use = max(1, detected_cpu_count) + cfg.num_workers = num_cpus_to_use device_batch_size = 1 dataspec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) dataloader = dataspec.dataloader - token_counting_func = dataspec.get_num_tokens_in_batch - - total_tokens = [] - for batch in tqdm(dataloader): - n_batch_tokens = token_counting_func(batch) - if n_batch_tokens == 0: - raise ValueError('Empty train sample') - total_tokens.append(n_batch_tokens) - return total_tokens + collate_fn, dataloader_batch_size = _build_collate_fn( + cfg, tokenizer, device_batch_size) + + def mapper(example: dict): + batch = collate_fn([example]) + return get_num_samples_in_batch(batch) + + token_lens = dataloader.dataset.map( + mapper, + batched=False, + num_proc=num_cpus_to_use, + desc='List of Token length', + ) + + return token_lens from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast @@ -123,7 +132,7 @@ def get_num_samples_in_batch(batch: dict) -> int: # Count number of non padding tokens in batch if 'attention_mask' in batch: - input_ids_tokens = batch['attention_mask'].numel() # int(sum(batch['attention_mask'])) + input_ids_tokens = int(torch.sum(batch['attention_mask']).item()) else: input_ids_tokens = batch['input_ids'].numel() @@ -152,9 +161,9 @@ def token_counts(FT_API_args): dataspec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) dataloader = dataspec.dataloader - token_lens = 0 + token_lens = [] for b in dataloader: - token_lens += get_num_samples_in_batch(b) + token_lens.append(get_num_samples_in_batch(b)['ntokens']) #token_lens = dataloader.dataset.map( # get_num_samples_in_batch,